diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..e1f508b --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.x86_64-unknown-linux-musl] +rustflags = ["-C", "link-arg=-static"] + +[target.aarch64-unknown-linux-musl] +rustflags = ["-C", "link-arg=-static"] diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..7a2c253 --- /dev/null +++ b/.env.example @@ -0,0 +1,70 @@ +# ZeroClaw Environment Variables +# Copy this file to `.env` and fill in your local values. +# Never commit `.env` or any real secrets. + +# ── Core Runtime ────────────────────────────────────────────── +# Provider key resolution at runtime: +# 1) explicit key passed from config/CLI +# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...) +# 3) generic fallback env vars below + +# Generic fallback API key (used when provider-specific key is absent) +API_KEY=your-api-key-here +# ZEROCLAW_API_KEY=your-api-key-here + +# Default provider/model (can be overridden by CLI flags) +PROVIDER=openrouter +# ZEROCLAW_PROVIDER=openrouter +# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514 +# ZEROCLAW_TEMPERATURE=0.7 + +# Workspace directory override +# ZEROCLAW_WORKSPACE=/path/to/workspace + +# ── Provider-Specific API Keys ──────────────────────────────── +# OpenRouter +# OPENROUTER_API_KEY=sk-or-v1-... + +# Anthropic +# ANTHROPIC_OAUTH_TOKEN=... +# ANTHROPIC_API_KEY=sk-ant-... + +# OpenAI / Gemini +# OPENAI_API_KEY=sk-... +# GEMINI_API_KEY=... +# GOOGLE_API_KEY=... + +# Other supported providers +# VENICE_API_KEY=... +# GROQ_API_KEY=... +# MISTRAL_API_KEY=... +# DEEPSEEK_API_KEY=... +# XAI_API_KEY=... +# TOGETHER_API_KEY=... +# FIREWORKS_API_KEY=... +# PERPLEXITY_API_KEY=... +# COHERE_API_KEY=... +# MOONSHOT_API_KEY=... +# GLM_API_KEY=... +# MINIMAX_API_KEY=... +# QIANFAN_API_KEY=... +# DASHSCOPE_API_KEY=... +# ZAI_API_KEY=... +# SYNTHETIC_API_KEY=... +# OPENCODE_API_KEY=... +# VERCEL_API_KEY=... +# CLOUDFLARE_API_KEY=... + +# ── Gateway ────────────────────────────────────────────────── +# ZEROCLAW_GATEWAY_PORT=3000 +# ZEROCLAW_GATEWAY_HOST=127.0.0.1 +# ZEROCLAW_ALLOW_PUBLIC_BIND=false + +# ── Optional Integrations ──────────────────────────────────── +# Pushover notifications (`pushover` tool) +# PUSHOVER_TOKEN=your-pushover-app-token +# PUSHOVER_USER_KEY=your-pushover-user-key + +# ── Docker Compose ─────────────────────────────────────────── +# Host port mapping (used by docker-compose.yml) +# HOST_PORT=3000 diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..d162ba3 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +if command -v gitleaks >/dev/null 2>&1; then + gitleaks protect --staged --redact +else + echo "warning: gitleaks not found; skipping staged secret scan" >&2 +fi diff --git a/.githooks/pre-push b/.githooks/pre-push index 4d8eea7..f69e1cb 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -6,21 +6,46 @@ set -euo pipefail -echo "==> pre-push: checking formatting..." -cargo fmt -- --check || { - echo "FAIL: cargo fmt -- --check found unformatted code." - echo "Run 'cargo fmt' and try again." +echo "==> pre-push: running rust quality gate..." +./scripts/ci/rust_quality_gate.sh || { + echo "FAIL: rust quality gate failed." exit 1 } -echo "==> pre-push: running clippy..." -cargo clippy -- -D warnings || { - echo "FAIL: clippy reported warnings." - exit 1 -} +if [ "${ZEROCLAW_STRICT_LINT:-0}" = "1" ]; then + echo "==> pre-push: running strict clippy warnings gate (ZEROCLAW_STRICT_LINT=1)..." + ./scripts/ci/rust_quality_gate.sh --strict || { + echo "FAIL: strict clippy warnings gate reported issues." + exit 1 + } +fi + +if [ "${ZEROCLAW_STRICT_DELTA_LINT:-0}" = "1" ]; then + echo "==> pre-push: running strict delta lint gate (ZEROCLAW_STRICT_DELTA_LINT=1)..." + ./scripts/ci/rust_strict_delta_gate.sh || { + echo "FAIL: strict delta lint gate reported issues." + exit 1 + } +fi + +if [ "${ZEROCLAW_DOCS_LINT:-0}" = "1" ]; then + echo "==> pre-push: running docs quality gate (ZEROCLAW_DOCS_LINT=1)..." + ./scripts/ci/docs_quality_gate.sh || { + echo "FAIL: docs quality gate reported issues." + exit 1 + } +fi + +if [ "${ZEROCLAW_DOCS_LINKS:-0}" = "1" ]; then + echo "==> pre-push: running docs links gate (ZEROCLAW_DOCS_LINKS=1)..." + ./scripts/ci/docs_links_gate.sh || { + echo "FAIL: docs links gate reported issues." + exit 1 + } +fi echo "==> pre-push: running tests..." -cargo test || { +cargo test --locked || { echo "FAIL: some tests did not pass." exit 1 } diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..776fb65 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,28 @@ +# Default owner for all files +* @theonlyhennygod + +# High-risk surfaces +/src/security/** @willsarg +/src/runtime/** @theonlyhennygod +/src/memory/** @theonlyhennygod @chumyin +/.github/** @theonlyhennygod +/Cargo.toml @theonlyhennygod +/Cargo.lock @theonlyhennygod + +# CI +/.github/workflows/** @theonlyhennygod @willsarg +/.github/codeql/** @willsarg +/.github/dependabot.yml @willsarg + +# Docs & governance +/docs/** @chumyin +/AGENTS.md @chumyin +/CLAUDE.md @chumyin +/CONTRIBUTING.md @chumyin +/docs/pr-workflow.md @chumyin +/docs/reviewer-playbook.md @chumyin + +# Security / CI-CD governance overrides (last-match wins) +/SECURITY.md @willsarg +/docs/actions-source-policy.md @willsarg +/docs/ci-map.md @willsarg diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..8ac7419 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,148 @@ +name: Bug Report +description: Report a reproducible defect in ZeroClaw +title: "[Bug]: " +labels: + - bug +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to report a bug. + Please provide a minimal reproducible case so maintainers can triage quickly. + Do not include personal/sensitive data; redact and anonymize all logs/payloads. + + - type: input + id: summary + attributes: + label: Summary + description: One-line description of the problem. + placeholder: zeroclaw daemon exits immediately when ... + validations: + required: true + + - type: dropdown + id: component + attributes: + label: Affected component + options: + - runtime/daemon + - provider + - channel + - memory + - security/sandbox + - tooling/ci + - docs + - unknown + validations: + required: true + + - type: dropdown + id: severity + attributes: + label: Severity + options: + - S0 - data loss / security risk + - S1 - workflow blocked + - S2 - degraded behavior + - S3 - minor issue + validations: + required: true + + - type: textarea + id: current + attributes: + label: Current behavior + description: What is happening now? + placeholder: The process exits with ... + validations: + required: true + + - type: textarea + id: expected + attributes: + label: Expected behavior + description: What should happen instead? + placeholder: The daemon should stay alive and ... + validations: + required: true + + - type: textarea + id: reproduce + attributes: + label: Steps to reproduce + description: Please provide exact commands/config. + placeholder: | + 1. zeroclaw onboard --interactive + 2. zeroclaw daemon + 3. Observe crash in logs + render: bash + validations: + required: true + + - type: textarea + id: impact + attributes: + label: Impact + description: Who is affected, how often, and practical consequences. + placeholder: | + Affected users: ... + Frequency: always/intermittent + Consequence: ... + validations: + required: true + + - type: textarea + id: logs + attributes: + label: Logs / stack traces + description: Paste relevant logs (redact secrets, personal identifiers, and sensitive data). + render: text + validations: + required: false + + - type: input + id: version + attributes: + label: ZeroClaw version + placeholder: v0.1.0 / commit SHA + validations: + required: true + + - type: input + id: rust + attributes: + label: Rust version + placeholder: rustc 1.xx.x + validations: + required: true + + - type: input + id: os + attributes: + label: Operating system + placeholder: Ubuntu 24.04 / macOS 15 / Windows 11 + validations: + required: true + + - type: dropdown + id: regression + attributes: + label: Regression? + options: + - Unknown + - Yes, it worked before + - No, first-time setup + validations: + required: true + + - type: checkboxes + id: checks + attributes: + label: Pre-flight checks + options: + - label: I reproduced this on the latest main branch or latest release. + required: true + - label: I redacted secrets/tokens from logs. + required: true + - label: I removed personal identifiers and replaced identity-specific data with neutral placeholders. + required: true diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..75945ca --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: false +contact_links: + - name: Security vulnerability report + url: https://github.com/zeroclaw-labs/zeroclaw/security/policy + about: Please report security vulnerabilities privately via SECURITY.md policy. + - name: Contribution guide + url: https://github.com/zeroclaw-labs/zeroclaw/blob/main/CONTRIBUTING.md + about: Please read contribution and PR requirements before opening an issue. + - name: PR workflow & reviewer expectations + url: https://github.com/zeroclaw-labs/zeroclaw/blob/main/docs/pr-workflow.md + about: Read risk-based PR tracks, CI gates, and merge criteria before filing feature requests. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..44553aa --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,107 @@ +name: Feature Request +description: Propose an improvement or new capability +title: "[Feature]: " +labels: + - enhancement +body: + - type: markdown + attributes: + value: | + Thanks for sharing your idea. + Please focus on user value, constraints, and rollout safety. + Do not include personal/sensitive data; use neutral project-scoped placeholders. + + - type: input + id: summary + attributes: + label: Summary + description: One-line statement of the requested capability. + placeholder: Add a provider-level retry budget override for long-running channels. + validations: + required: true + + - type: textarea + id: problem + attributes: + label: Problem statement + description: What user pain does this solve and why is current behavior insufficient? + placeholder: Teams operating in unstable networks cannot tune retries per provider... + validations: + required: true + + - type: textarea + id: proposal + attributes: + label: Proposed solution + description: Describe preferred behavior and interfaces. + placeholder: Add `[provider.retry]` config and enforce bounds in config validation. + validations: + required: true + + - type: textarea + id: non_goals + attributes: + label: Non-goals / out of scope + description: Clarify what should not be included in the first iteration. + placeholder: No UI changes, no cross-provider dynamic adaptation in v1. + validations: + required: true + + - type: textarea + id: alternatives + attributes: + label: Alternatives considered + description: What alternatives did you evaluate? + placeholder: Keep current behavior, use wrapper scripts, etc. + validations: + required: false + + - type: textarea + id: acceptance + attributes: + label: Acceptance criteria + description: What outcomes would make this request complete? + placeholder: | + - Config key is documented and validated + - Runtime path uses configured retry budget + - Regression tests cover fallback and invalid config + validations: + required: true + + - type: textarea + id: architecture + attributes: + label: Architecture impact + description: Which subsystem(s) are affected? + placeholder: providers/, channels/, memory/, runtime/, security/, docs/ ... + validations: + required: true + + - type: textarea + id: risk + attributes: + label: Risk and rollback + description: Main risk + how to disable/revert quickly. + placeholder: Risk is ... rollback is ... + validations: + required: true + + - type: dropdown + id: breaking + attributes: + label: Breaking change? + options: + - No + - Yes + validations: + required: true + + - type: checkboxes + id: hygiene + attributes: + label: Data hygiene checks + options: + - label: I removed personal/sensitive data from examples, payloads, and logs. + required: true + - label: I used neutral, project-focused wording and placeholders. + required: true diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml new file mode 100644 index 0000000..1c422ab --- /dev/null +++ b/.github/actionlint.yaml @@ -0,0 +1,4 @@ +self-hosted-runner: + labels: + - lxc-ci + - blacksmith-2vcpu-ubuntu-2404 diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 0000000..5c82c1b --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,8 @@ +# CodeQL configuration for ZeroClaw +# +# We intentionally ignore integration tests under `tests/` because they often +# contain security-focused fixtures (example secrets, malformed payloads, etc.) +# that can trigger false positives in security queries. + +paths-ignore: + - tests/** diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..1696124 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,35 @@ +version: 2 + +updates: + - package-ecosystem: cargo + directory: "/" + schedule: + interval: weekly + target-branch: main + open-pull-requests-limit: 5 + labels: + - "dependencies" + groups: + rust-minor-patch: + patterns: + - "*" + update-types: + - minor + - patch + + - package-ecosystem: github-actions + directory: "/" + schedule: + interval: weekly + target-branch: main + open-pull-requests-limit: 3 + labels: + - "ci" + - "dependencies" + groups: + actions-minor-patch: + patterns: + - "*" + update-types: + - minor + - patch diff --git a/.github/label-policy.json b/.github/label-policy.json new file mode 100644 index 0000000..e8b254f --- /dev/null +++ b/.github/label-policy.json @@ -0,0 +1,21 @@ +{ + "contributor_tier_color": "2ED9FF", + "contributor_tiers": [ + { + "label": "distinguished contributor", + "min_merged_prs": 50 + }, + { + "label": "principal contributor", + "min_merged_prs": 20 + }, + { + "label": "experienced contributor", + "min_merged_prs": 10 + }, + { + "label": "trusted contributor", + "min_merged_prs": 5 + } + ] +} diff --git a/.github/labeler.yml b/.github/labeler.yml new file mode 100644 index 0000000..21e851f --- /dev/null +++ b/.github/labeler.yml @@ -0,0 +1,147 @@ +"docs": + - changed-files: + - any-glob-to-any-file: + - "docs/**" + - "**/*.md" + - "**/*.mdx" + - "LICENSE" + - ".markdownlint-cli2.yaml" + +"dependencies": + - changed-files: + - any-glob-to-any-file: + - "Cargo.toml" + - "Cargo.lock" + - "deny.toml" + - ".github/dependabot.yml" + +"ci": + - changed-files: + - any-glob-to-any-file: + - ".github/**" + - ".githooks/**" + +"core": + - changed-files: + - any-glob-to-any-file: + - "src/*.rs" + +"agent": + - changed-files: + - any-glob-to-any-file: + - "src/agent/**" + +"channel": + - changed-files: + - any-glob-to-any-file: + - "src/channels/**" + +"gateway": + - changed-files: + - any-glob-to-any-file: + - "src/gateway/**" + +"config": + - changed-files: + - any-glob-to-any-file: + - "src/config/**" + +"cron": + - changed-files: + - any-glob-to-any-file: + - "src/cron/**" + +"daemon": + - changed-files: + - any-glob-to-any-file: + - "src/daemon/**" + +"doctor": + - changed-files: + - any-glob-to-any-file: + - "src/doctor/**" + +"health": + - changed-files: + - any-glob-to-any-file: + - "src/health/**" + +"heartbeat": + - changed-files: + - any-glob-to-any-file: + - "src/heartbeat/**" + +"integration": + - changed-files: + - any-glob-to-any-file: + - "src/integrations/**" + +"memory": + - changed-files: + - any-glob-to-any-file: + - "src/memory/**" + +"security": + - changed-files: + - any-glob-to-any-file: + - "src/security/**" + +"runtime": + - changed-files: + - any-glob-to-any-file: + - "src/runtime/**" + +"onboard": + - changed-files: + - any-glob-to-any-file: + - "src/onboard/**" + +"provider": + - changed-files: + - any-glob-to-any-file: + - "src/providers/**" + +"service": + - changed-files: + - any-glob-to-any-file: + - "src/service/**" + +"skillforge": + - changed-files: + - any-glob-to-any-file: + - "src/skillforge/**" + +"skills": + - changed-files: + - any-glob-to-any-file: + - "src/skills/**" + +"tool": + - changed-files: + - any-glob-to-any-file: + - "src/tools/**" + +"tunnel": + - changed-files: + - any-glob-to-any-file: + - "src/tunnel/**" + +"observability": + - changed-files: + - any-glob-to-any-file: + - "src/observability/**" + +"tests": + - changed-files: + - any-glob-to-any-file: + - "tests/**" + +"scripts": + - changed-files: + - any-glob-to-any-file: + - "scripts/**" + +"dev": + - changed-files: + - any-glob-to-any-file: + - "dev/**" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..7c9e601 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,109 @@ +## Summary + +Describe this PR in 2-5 bullets: + +- Problem: +- Why it matters: +- What changed: +- What did **not** change (scope boundary): + +## Label Snapshot (required) + +- Risk label (`risk: low|medium|high`): +- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only): +- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated): +<<<<<<< chore/labeler-spacing-trusted-tier +- Module labels (`: `, for example `channel: telegram`, `provider: kimi`, `tool: shell`): +======= +- Module labels (`:`, for example `channel:telegram`, `provider:kimi`, `tool:shell`): +>>>>>>> main +- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50): +- If any auto-label is incorrect, note requested correction: + +## Change Metadata + +- Change type (`bug|feature|refactor|docs|security|chore`): +- Primary scope (`runtime|provider|channel|memory|security|ci|docs|multi`): + +## Linked Issue + +- Closes # +- Related # +- Depends on # (if stacked) +- Supersedes # (if replacing older PR) + +## Supersede Attribution (required when `Supersedes #` is used) + +- Superseded PRs + authors (`# by @`, one per line): +- Integrated scope by source PR (what was materially carried forward): +- `Co-authored-by` trailers added for materially incorporated contributors? (`Yes/No`) +- If `No`, explain why (for example: inspiration-only, no direct code/design carry-over): +- Trailer format check (separate lines, no escaped `\n`): (`Pass/Fail`) + +## Validation Evidence (required) + +Commands and result summary: + +```bash +cargo fmt --all -- --check +cargo clippy --all-targets -- -D warnings +cargo test +``` + +- Evidence provided (test/log/trace/screenshot/perf): +- If any command is intentionally skipped, explain why: + +## Security Impact (required) + +- New permissions/capabilities? (`Yes/No`) +- New external network calls? (`Yes/No`) +- Secrets/tokens handling changed? (`Yes/No`) +- File system access scope changed? (`Yes/No`) +- If any `Yes`, describe risk and mitigation: + +## Privacy and Data Hygiene (required) + +- Data-hygiene status (`pass|needs-follow-up`): +- Redaction/anonymization notes: +- Neutral wording confirmation (use ZeroClaw/project-native labels if identity-like wording is needed): + +## Compatibility / Migration + +- Backward compatible? (`Yes/No`) +- Config/env changes? (`Yes/No`) +- Migration needed? (`Yes/No`) +- If yes, exact upgrade steps: + +## Human Verification (required) + +What was personally validated beyond CI: + +- Verified scenarios: +- Edge cases checked: +- What was not verified: + +## Side Effects / Blast Radius (required) + +- Affected subsystems/workflows: +- Potential unintended effects: +- Guardrails/monitoring for early detection: + +## Agent Collaboration Notes (recommended) + +- Agent tools used (if any): +- Workflow/plan summary (if any): +- Verification focus: +- Confirmation: naming + architecture boundaries followed (`AGENTS.md` + `CONTRIBUTING.md`): + +## Rollback Plan (required) + +- Fast rollback command/path: +- Feature flags or config toggles (if any): +- Observable failure symptoms: + +## Risks and Mitigations + +List real risks in this PR (or write `None`). + +- Risk: + - Mitigation: diff --git a/.github/workflows/auto-response.yml b/.github/workflows/auto-response.yml new file mode 100644 index 0000000..3065182 --- /dev/null +++ b/.github/workflows/auto-response.yml @@ -0,0 +1,285 @@ +name: PR Auto Responder + +on: + issues: + types: [opened, reopened, labeled, unlabeled] + pull_request_target: + types: [opened, labeled, unlabeled] + +permissions: {} + +jobs: + contributor-tier-issues: + if: >- + (github.event_name == 'issues' && + (github.event.action == 'opened' || github.event.action == 'reopened' || github.event.action == 'labeled' || github.event.action == 'unlabeled')) || + (github.event_name == 'pull_request_target' && + (github.event.action == 'labeled' || github.event.action == 'unlabeled')) + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + issues: write + pull-requests: write + steps: + - name: Apply contributor tier label for issue author + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + with: + script: | + const owner = context.repo.owner; + const repo = context.repo.repo; + const issue = context.payload.issue; + const pullRequest = context.payload.pull_request; + const target = issue ?? pullRequest; + async function loadContributorTierPolicy() { + const fallback = { + contributorTierColor: "2ED9FF", + contributorTierRules: [ + { label: "distinguished contributor", minMergedPRs: 50 }, + { label: "principal contributor", minMergedPRs: 20 }, + { label: "experienced contributor", minMergedPRs: 10 }, + { label: "trusted contributor", minMergedPRs: 5 }, + ], + }; + try { + const { data } = await github.rest.repos.getContent({ + owner, + repo, + path: ".github/label-policy.json", + ref: context.payload.repository?.default_branch || "main", + }); + const json = JSON.parse(Buffer.from(data.content, "base64").toString("utf8")); + const contributorTierRules = (json.contributor_tiers || []).map((entry) => ({ + label: String(entry.label || "").trim(), + minMergedPRs: Number(entry.min_merged_prs || 0), + })); + const contributorTierColor = String(json.contributor_tier_color || "").toUpperCase(); + if (!contributorTierColor || contributorTierRules.length === 0) { + return fallback; + } + return { contributorTierColor, contributorTierRules }; + } catch (error) { + core.warning(`failed to load .github/label-policy.json, using fallback policy: ${error.message}`); + return fallback; + } + } + + const { contributorTierColor, contributorTierRules } = await loadContributorTierPolicy(); + const contributorTierLabels = contributorTierRules.map((rule) => rule.label); + const managedContributorLabels = new Set(contributorTierLabels); + const action = context.payload.action; + const changedLabel = context.payload.label?.name; + + if (!target) return; + if ((action === "labeled" || action === "unlabeled") && !managedContributorLabels.has(changedLabel)) { + return; + } + + const author = target.user; + if (!author || author.type === "Bot") return; + + function contributorTierDescription(rule) { + return `Contributor with ${rule.minMergedPRs}+ merged PRs.`; + } + + async function ensureContributorTierLabels() { + for (const rule of contributorTierRules) { + const label = rule.label; + const expectedDescription = contributorTierDescription(rule); + try { + const { data: existing } = await github.rest.issues.getLabel({ owner, repo, name: label }); + const currentColor = (existing.color || "").toUpperCase(); + const currentDescription = (existing.description || "").trim(); + if (currentColor !== contributorTierColor || currentDescription !== expectedDescription) { + await github.rest.issues.updateLabel({ + owner, + repo, + name: label, + new_name: label, + color: contributorTierColor, + description: expectedDescription, + }); + } + } catch (error) { + if (error.status !== 404) throw error; + await github.rest.issues.createLabel({ + owner, + repo, + name: label, + color: contributorTierColor, + description: expectedDescription, + }); + } + } + } + + function selectContributorTier(mergedCount) { + const matchedTier = contributorTierRules.find((rule) => mergedCount >= rule.minMergedPRs); + return matchedTier ? matchedTier.label : null; + } + + let contributorTierLabel = null; + try { + const { data: mergedSearch } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${owner}/${repo} is:pr is:merged author:${author.login}`, + per_page: 1, + }); + const mergedCount = mergedSearch.total_count || 0; + contributorTierLabel = selectContributorTier(mergedCount); + } catch (error) { + core.warning(`failed to evaluate contributor tier status: ${error.message}`); + return; + } + + await ensureContributorTierLabels(); + + const { data: currentLabels } = await github.rest.issues.listLabelsOnIssue({ + owner, + repo, + issue_number: target.number, + }); + const keepLabels = currentLabels + .map((label) => label.name) + .filter((label) => !contributorTierLabels.includes(label)); + + if (contributorTierLabel) { + keepLabels.push(contributorTierLabel); + } + + await github.rest.issues.setLabels({ + owner, + repo, + issue_number: target.number, + labels: [...new Set(keepLabels)], + }); + + first-interaction: + if: github.event.action == 'opened' + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + issues: write + pull-requests: write + steps: + - name: Greet first-time contributors + uses: actions/first-interaction@2ec0f0fd78838633cd1c1342e4536d49ef72be54 # v1 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + issue-message: | + Thanks for opening this issue. + + Before maintainers triage it, please confirm: + - Repro steps are complete and run on latest `main` + - Environment details are included (OS, Rust version, ZeroClaw version) + - Sensitive values are redacted + + This helps us keep issue throughput high and response latency low. + pr-message: | + Thanks for contributing to ZeroClaw. + + For faster review, please ensure: + - PR template sections are fully completed + - `cargo fmt --all -- --check`, `cargo clippy --all-targets -- -D warnings`, and `cargo test` are included + - If automation/agents were used heavily, add brief workflow notes + - Scope is focused (prefer one concern per PR) + + See `CONTRIBUTING.md` and `docs/pr-workflow.md` for full collaboration rules. + + labeled-routes: + if: github.event.action == 'labeled' + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + issues: write + pull-requests: write + steps: + - name: Handle label-driven responses + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + with: + script: | + const label = context.payload.label?.name; + if (!label) return; + + const issue = context.payload.issue; + const pullRequest = context.payload.pull_request; + const target = issue ?? pullRequest; + if (!target) return; + + const isIssue = Boolean(issue); + const issueNumber = target.number; + const owner = context.repo.owner; + const repo = context.repo.repo; + + const rules = [ + { + label: "r:support", + close: true, + closeIssuesOnly: true, + closeReason: "not_planned", + message: + "This looks like a usage/support request. Please use README + docs first, then open a focused bug with repro details if behavior is incorrect.", + }, + { + label: "r:needs-repro", + close: false, + message: + "Thanks for the report. Please add deterministic repro steps, exact environment, and redacted logs so maintainers can triage quickly.", + }, + { + label: "invalid", + close: true, + closeIssuesOnly: true, + closeReason: "not_planned", + message: + "Closing as invalid based on current information. If this is still relevant, open a new issue with updated evidence and reproducible steps.", + }, + { + label: "duplicate", + close: true, + closeIssuesOnly: true, + closeReason: "not_planned", + message: + "Closing as duplicate. Please continue discussion in the canonical linked issue/PR.", + }, + ]; + + const rule = rules.find((entry) => entry.label === label); + if (!rule) return; + + const marker = ``; + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number: issueNumber, + per_page: 100, + }); + + const alreadyCommented = comments.some((comment) => + (comment.body || "").includes(marker) + ); + + if (!alreadyCommented) { + await github.rest.issues.createComment({ + owner, + repo, + issue_number: issueNumber, + body: `${rule.message}\n\n${marker}`, + }); + } + + if (!rule.close) return; + if (rule.closeIssuesOnly && !isIssue) return; + if (target.state === "closed") return; + + if (isIssue) { + await github.rest.issues.update({ + owner, + repo, + issue_number: issueNumber, + state: "closed", + state_reason: rule.closeReason || "not_planned", + }); + } else { + await github.rest.issues.update({ + owner, + repo, + issue_number: issueNumber, + state: "closed", + }); + } diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7860946..e377d15 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,45 +1,531 @@ name: CI on: - push: - branches: [main, develop] - pull_request: - branches: [main] + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ci-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +permissions: + contents: read env: - CARGO_TERM_COLOR: always + CARGO_TERM_COLOR: always jobs: - test: - name: Test - runs-on: ubuntu-latest - continue-on-error: true # Don't block PRs - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 - - name: Run tests - run: cargo test --verbose + changes: + name: Detect Change Scope + runs-on: blacksmith-2vcpu-ubuntu-2404 + outputs: + docs_only: ${{ steps.scope.outputs.docs_only }} + docs_changed: ${{ steps.scope.outputs.docs_changed }} + rust_changed: ${{ steps.scope.outputs.rust_changed }} + workflow_changed: ${{ steps.scope.outputs.workflow_changed }} + docs_files: ${{ steps.scope.outputs.docs_files }} + base_sha: ${{ steps.scope.outputs.base_sha }} + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + with: + fetch-depth: 0 - build: - name: Build - runs-on: ${{ matrix.os }} - continue-on-error: true # Don't block PRs - strategy: - matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - - os: macos-latest - target: x86_64-apple-darwin - - os: macos-latest - target: aarch64-apple-darwin - - os: windows-latest - target: x86_64-pc-windows-msvc + - name: Detect docs-only changes + id: scope + shell: bash + run: | + set -euo pipefail - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 - - name: Build - run: cargo build --release --verbose + write_empty_docs_files() { + { + echo "docs_files<> "$GITHUB_OUTPUT" + } + + if [ "${{ github.event_name }}" = "pull_request" ]; then + BASE="${{ github.event.pull_request.base.sha }}" + else + BASE="${{ github.event.before }}" + fi + + if [ -z "$BASE" ] || ! git cat-file -e "$BASE^{commit}" 2>/dev/null; then + { + echo "docs_only=false" + echo "docs_changed=false" + echo "rust_changed=true" + echo "workflow_changed=false" + echo "base_sha=" + } >> "$GITHUB_OUTPUT" + write_empty_docs_files + exit 0 + fi + + CHANGED="$(git diff --name-only "$BASE" HEAD || true)" + if [ -z "$CHANGED" ]; then + { + echo "docs_only=false" + echo "docs_changed=false" + echo "rust_changed=false" + echo "workflow_changed=false" + echo "base_sha=$BASE" + } >> "$GITHUB_OUTPUT" + write_empty_docs_files + exit 0 + fi + + docs_only=true + docs_changed=false + rust_changed=false + workflow_changed=false + docs_files=() + while IFS= read -r file; do + [ -z "$file" ] && continue + + if [[ "$file" == .github/workflows/* ]]; then + workflow_changed=true + fi + + if [[ "$file" == docs/* ]] \ + || [[ "$file" == *.md ]] \ + || [[ "$file" == *.mdx ]] \ + || [[ "$file" == "LICENSE" ]] \ + || [[ "$file" == ".markdownlint-cli2.yaml" ]] \ + || [[ "$file" == .github/ISSUE_TEMPLATE/* ]] \ + || [[ "$file" == .github/pull_request_template.md ]]; then + if [[ "$file" == *.md ]] \ + || [[ "$file" == *.mdx ]] \ + || [[ "$file" == "LICENSE" ]] \ + || [[ "$file" == .github/pull_request_template.md ]]; then + docs_changed=true + docs_files+=("$file") + fi + continue + fi + + docs_only=false + + if [[ "$file" == src/* ]] \ + || [[ "$file" == tests/* ]] \ + || [[ "$file" == "Cargo.toml" ]] \ + || [[ "$file" == "Cargo.lock" ]] \ + || [[ "$file" == "deny.toml" ]]; then + rust_changed=true + fi + done <<< "$CHANGED" + + { + echo "docs_only=$docs_only" + echo "docs_changed=$docs_changed" + echo "rust_changed=$rust_changed" + echo "workflow_changed=$workflow_changed" + echo "base_sha=$BASE" + echo "docs_files<> "$GITHUB_OUTPUT" + + lint: + name: Lint Gate (Format + Clippy) + needs: [changes] + if: needs.changes.outputs.rust_changed == 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 20 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + with: + fetch-depth: 0 + - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + toolchain: 1.92.0 + components: rustfmt, clippy + - uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 + - name: Run rust quality gate + run: ./scripts/ci/rust_quality_gate.sh + + lint-strict-delta: + name: Lint Gate (Strict Delta) + needs: [changes] + if: needs.changes.outputs.rust_changed == 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 25 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + with: + fetch-depth: 0 + - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + toolchain: 1.92.0 + components: clippy + - uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 + - name: Run strict lint delta gate + env: + BASE_SHA: ${{ needs.changes.outputs.base_sha }} + run: ./scripts/ci/rust_strict_delta_gate.sh + + test: + name: Test + needs: [changes, lint, lint-strict-delta] + if: needs.changes.outputs.rust_changed == 'true' && needs.lint.result == 'success' && needs.lint-strict-delta.result == 'success' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 30 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + toolchain: 1.92.0 + - uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 + - name: Run tests + run: cargo test --locked --verbose + + build: + name: Build (Smoke) + needs: [changes, lint, lint-strict-delta] + if: needs.changes.outputs.rust_changed == 'true' && needs.lint.result == 'success' && needs.lint-strict-delta.result == 'success' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 20 + + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + toolchain: 1.92.0 + - uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 + - name: Build release binary + run: cargo build --release --locked --verbose + + docs-only: + name: Docs-Only Fast Path + needs: [changes] + if: needs.changes.outputs.docs_only == 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - name: Skip heavy jobs for docs-only change + run: echo "Docs-only change detected. Rust lint/test/build skipped." + + non-rust: + name: Non-Rust Fast Path + needs: [changes] + if: needs.changes.outputs.docs_only != 'true' && needs.changes.outputs.rust_changed != 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - name: Skip Rust jobs for non-Rust change scope + run: echo "No Rust-impacting files changed. Rust lint/test/build skipped." + + docs-quality: + name: Docs Quality + needs: [changes] + if: needs.changes.outputs.docs_changed == 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 15 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + with: + fetch-depth: 0 + + - name: Markdown lint (changed lines only) + env: + BASE_SHA: ${{ needs.changes.outputs.base_sha }} + DOCS_FILES: ${{ needs.changes.outputs.docs_files }} + run: ./scripts/ci/docs_quality_gate.sh + + - name: Collect added links + id: collect_links + shell: bash + env: + BASE_SHA: ${{ needs.changes.outputs.base_sha }} + DOCS_FILES: ${{ needs.changes.outputs.docs_files }} + run: | + set -euo pipefail + python3 ./scripts/ci/collect_changed_links.py \ + --base "$BASE_SHA" \ + --docs-files "$DOCS_FILES" \ + --output .ci-added-links.txt + count=$(wc -l < .ci-added-links.txt | tr -d ' ') + echo "count=$count" >> "$GITHUB_OUTPUT" + if [ "$count" -gt 0 ]; then + echo "Added links queued for check:" + cat .ci-added-links.txt + else + echo "No added links found in changed docs lines." + fi + + - name: Link check (offline, added links only) + if: steps.collect_links.outputs.count != '0' + uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2 + with: + fail: true + args: >- + --offline + --no-progress + --format detailed + .ci-added-links.txt + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Skip link check (no added links) + if: steps.collect_links.outputs.count == '0' + run: echo "No added links in changed docs lines. Link check skipped." + + lint-feedback: + name: Lint Feedback + if: github.event_name == 'pull_request' + needs: [changes, lint, lint-strict-delta, docs-quality] + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + contents: read + pull-requests: write + issues: write + steps: + - name: Post actionable lint failure summary + if: always() + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + env: + RUST_CHANGED: ${{ needs.changes.outputs.rust_changed }} + DOCS_CHANGED: ${{ needs.changes.outputs.docs_changed }} + LINT_RESULT: ${{ needs.lint.result }} + LINT_DELTA_RESULT: ${{ needs.lint-strict-delta.result }} + DOCS_RESULT: ${{ needs.docs-quality.result }} + with: + script: | + const owner = context.repo.owner; + const repo = context.repo.repo; + const issueNumber = context.payload.pull_request?.number; + if (!issueNumber) return; + + const marker = ""; + const rustChanged = process.env.RUST_CHANGED === "true"; + const docsChanged = process.env.DOCS_CHANGED === "true"; + const lintResult = process.env.LINT_RESULT || "skipped"; + const lintDeltaResult = process.env.LINT_DELTA_RESULT || "skipped"; + const docsResult = process.env.DOCS_RESULT || "skipped"; + + const failures = []; + if (rustChanged && !["success", "skipped"].includes(lintResult)) { + failures.push("`Lint Gate (Format + Clippy)` failed."); + } + if (rustChanged && !["success", "skipped"].includes(lintDeltaResult)) { + failures.push("`Lint Gate (Strict Delta)` failed."); + } + if (docsChanged && !["success", "skipped"].includes(docsResult)) { + failures.push("`Docs Quality` failed."); + } + + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number: issueNumber, + per_page: 100, + }); + const existing = comments.find((comment) => (comment.body || "").includes(marker)); + + if (failures.length === 0) { + if (existing) { + await github.rest.issues.deleteComment({ + owner, + repo, + comment_id: existing.id, + }); + } + core.info("No lint/docs gate failures. No feedback comment required."); + return; + } + + const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`; + const body = [ + marker, + "### CI lint feedback", + "", + "This PR failed one or more fast lint/documentation gates:", + "", + ...failures.map((item) => `- ${item}`), + "", + "Open the failing logs in this run:", + `- ${runUrl}`, + "", + "Local fix commands:", + "- `./scripts/ci/rust_quality_gate.sh`", + "- `./scripts/ci/rust_strict_delta_gate.sh`", + "- `./scripts/ci/docs_quality_gate.sh`", + "", + "After fixes, push a new commit and CI will re-run automatically.", + ].join("\n"); + + if (existing) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existing.id, + body, + }); + } else { + await github.rest.issues.createComment({ + owner, + repo, + issue_number: issueNumber, + body, + }); + } + + workflow-owner-approval: + name: Workflow Owner Approval + needs: [changes] + if: github.event_name == 'pull_request' && needs.changes.outputs.workflow_changed == 'true' + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + contents: read + pull-requests: read + steps: + - name: Require owner approval for workflow file changes + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + env: + WORKFLOW_OWNER_LOGINS: ${{ vars.WORKFLOW_OWNER_LOGINS || 'theonlyhennygod,willsarg' }} + with: + script: | + const owner = context.repo.owner; + const repo = context.repo.repo; + const prNumber = context.payload.pull_request?.number; + if (!prNumber) { + core.setFailed("Missing pull_request context."); + return; + } + + const ownerAllowlist = (process.env.WORKFLOW_OWNER_LOGINS || "") + .split(",") + .map((login) => login.trim().toLowerCase()) + .filter(Boolean); + + if (ownerAllowlist.length === 0) { + core.setFailed("WORKFLOW_OWNER_LOGINS is empty. Set a repository variable or use a fallback value."); + return; + } + + const files = await github.paginate(github.rest.pulls.listFiles, { + owner, + repo, + pull_number: prNumber, + per_page: 100, + }); + + const workflowFiles = files + .map((file) => file.filename) + .filter((name) => name.startsWith(".github/workflows/")); + + if (workflowFiles.length === 0) { + core.info("No workflow files changed in this PR."); + return; + } + + core.info(`Workflow files changed:\n- ${workflowFiles.join("\n- ")}`); + + const reviews = await github.paginate(github.rest.pulls.listReviews, { + owner, + repo, + pull_number: prNumber, + per_page: 100, + }); + + const latestReviewByUser = new Map(); + for (const review of reviews) { + const login = review.user?.login; + if (!login) continue; + latestReviewByUser.set(login.toLowerCase(), review.state); + } + + const approvedUsers = [...latestReviewByUser.entries()] + .filter(([, state]) => state === "APPROVED") + .map(([login]) => login); + + if (approvedUsers.length === 0) { + core.setFailed("Workflow files changed but no approving review is present."); + return; + } + + const ownerApprover = approvedUsers.find((login) => ownerAllowlist.includes(login)); + if (!ownerApprover) { + core.setFailed( + `Workflow files changed. Approvals found (${approvedUsers.join(", ")}), but none match WORKFLOW_OWNER_LOGINS.`, + ); + return; + } + + core.info(`Workflow owner approval present: @${ownerApprover}`); + + ci-required: + name: CI Required Gate + if: always() + needs: [changes, lint, lint-strict-delta, test, build, docs-only, non-rust, docs-quality, lint-feedback, workflow-owner-approval] + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - name: Enforce required status + shell: bash + run: | + set -euo pipefail + + docs_changed="${{ needs.changes.outputs.docs_changed }}" + rust_changed="${{ needs.changes.outputs.rust_changed }}" + workflow_changed="${{ needs.changes.outputs.workflow_changed }}" + docs_result="${{ needs.docs-quality.result }}" + workflow_owner_result="${{ needs.workflow-owner-approval.result }}" + + if [ "${{ needs.changes.outputs.docs_only }}" = "true" ]; then + echo "docs=${docs_result}" + echo "workflow_owner_approval=${workflow_owner_result}" + if [ "$workflow_changed" = "true" ] && [ "$workflow_owner_result" != "success" ]; then + echo "Workflow files changed but workflow owner approval gate did not pass." + exit 1 + fi + if [ "$docs_changed" = "true" ] && [ "$docs_result" != "success" ]; then + echo "Docs-only change touched markdown docs, but docs-quality did not pass." + exit 1 + fi + echo "Docs-only fast path passed." + exit 0 + fi + + if [ "$rust_changed" != "true" ]; then + echo "rust_changed=false (non-rust fast path)" + echo "docs=${docs_result}" + echo "workflow_owner_approval=${workflow_owner_result}" + if [ "$workflow_changed" = "true" ] && [ "$workflow_owner_result" != "success" ]; then + echo "Workflow files changed but workflow owner approval gate did not pass." + exit 1 + fi + if [ "$docs_changed" = "true" ] && [ "$docs_result" != "success" ]; then + echo "Docs changed but docs-quality did not pass." + exit 1 + fi + echo "Non-rust fast path passed." + exit 0 + fi + + lint_result="${{ needs.lint.result }}" + lint_strict_delta_result="${{ needs.lint-strict-delta.result }}" + test_result="${{ needs.test.result }}" + build_result="${{ needs.build.result }}" + + echo "lint=${lint_result}" + echo "lint_strict_delta=${lint_strict_delta_result}" + echo "test=${test_result}" + echo "build=${build_result}" + echo "docs=${docs_result}" + echo "workflow_owner_approval=${workflow_owner_result}" + + if [ "$lint_result" != "success" ] || [ "$lint_strict_delta_result" != "success" ] || [ "$test_result" != "success" ] || [ "$build_result" != "success" ]; then + echo "Required CI jobs did not pass." + exit 1 + fi + + if [ "$workflow_changed" = "true" ] && [ "$workflow_owner_result" != "success" ]; then + echo "Workflow files changed but workflow owner approval gate did not pass." + exit 1 + fi + + if [ "$docs_changed" = "true" ] && [ "$docs_result" != "success" ]; then + echo "Docs changed but docs-quality did not pass." + exit 1 + fi + + echo "All required CI jobs passed." diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..81210b2 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,39 @@ +name: CodeQL Analysis + +on: + schedule: + - cron: "0 6,18 * * *" # Twice daily at 6am and 6pm UTC + workflow_dispatch: + +concurrency: + group: codeql-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: read + security-events: write + actions: read + +jobs: + codeql: + name: CodeQL Analysis + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 30 + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v4 + with: + languages: rust + config-file: ./.github/codeql/codeql-config.yml + + - name: Set up Rust + uses: dtolnay/rust-toolchain@stable + + - name: Build + run: cargo build --workspace --all-targets + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v4 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c1fe26d..67005c6 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,65 +1,110 @@ name: Docker on: - push: - branches: [main] - tags: ["v*"] - pull_request: - branches: [main] + push: + branches: [main] + tags: ["v*"] + pull_request: + branches: [main] + paths: + - "Dockerfile" + - "docker-compose.yml" + - "dev/docker-compose.yml" + - "dev/sandbox/**" + - ".github/workflows/docker.yml" + workflow_dispatch: + +concurrency: + group: docker-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true env: - REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} jobs: - build-and-push: - name: Build and Push Docker Image - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to Container Registry - if: github.event_name != 'pull_request' - uses: docker/login-action@v3 - with: - registry: ${{ env.REGISTRY }} - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Extract metadata (tags, labels) - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - tags: | - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} - type=raw,value=latest,enable={{is_default_branch}} - - - name: Build and push Docker image - uses: docker/build-push-action@v5 - with: - context: . - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max - platforms: linux/amd64,linux/arm64 - - - name: Verify image (PR only) + pr-smoke: + name: PR Docker Smoke if: github.event_name == 'pull_request' - run: | - docker build -t zeroclaw-test . - docker run --rm zeroclaw-test --version + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 25 + permissions: + contents: read + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Blacksmith Builder + uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1 + + - name: Extract metadata (tags, labels) + id: meta + uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=pr + + - name: Build smoke image + uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2 + with: + context: . + push: false + load: true + tags: zeroclaw-pr-smoke:latest + labels: ${{ steps.meta.outputs.labels }} + platforms: linux/amd64 + + - name: Verify image + run: docker run --rm zeroclaw-pr-smoke:latest --version + + publish: + name: Build and Push Docker Image + if: github.event_name == 'push' + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 25 + permissions: + contents: read + packages: write + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Blacksmith Builder + uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1 + + - name: Log in to Container Registry + uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Compute tags + id: meta + shell: bash + run: | + set -euo pipefail + IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}" + SHA_TAG="${IMAGE}:sha-${GITHUB_SHA::12}" + + if [[ "${GITHUB_REF}" == refs/tags/* ]]; then + TAG_NAME="${GITHUB_REF#refs/tags/}" + TAGS="${IMAGE}:${TAG_NAME},${SHA_TAG}" + elif [[ "${GITHUB_REF}" == "refs/heads/main" ]]; then + TAGS="${IMAGE}:latest,${SHA_TAG}" + else + BRANCH_NAME="${GITHUB_REF#refs/heads/}" + BRANCH_NAME="${BRANCH_NAME//\//-}" + TAGS="${IMAGE}:${BRANCH_NAME},${SHA_TAG}" + fi + + echo "tags=${TAGS}" >> "$GITHUB_OUTPUT" + + - name: Build and push Docker image + uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2 + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + platforms: ${{ startsWith(github.ref, 'refs/tags/') && 'linux/amd64,linux/arm64' || 'linux/amd64' }} diff --git a/.github/workflows/label-policy-sanity.yml b/.github/workflows/label-policy-sanity.yml new file mode 100644 index 0000000..de1bbda --- /dev/null +++ b/.github/workflows/label-policy-sanity.yml @@ -0,0 +1,74 @@ +name: Label Policy Sanity + +on: + pull_request: + paths: + - ".github/label-policy.json" + - ".github/workflows/labeler.yml" + - ".github/workflows/auto-response.yml" + push: + paths: + - ".github/label-policy.json" + - ".github/workflows/labeler.yml" + - ".github/workflows/auto-response.yml" + +concurrency: + group: label-policy-sanity-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + contributor-tier-consistency: + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Verify shared label policy and workflow wiring + shell: bash + run: | + set -euo pipefail + python3 - <<'PY' + import json + import re + from pathlib import Path + + policy_path = Path('.github/label-policy.json') + policy = json.loads(policy_path.read_text(encoding='utf-8')) + color = str(policy.get('contributor_tier_color', '')).upper() + rules = policy.get('contributor_tiers', []) + if not re.fullmatch(r'[0-9A-F]{6}', color): + raise SystemExit('invalid contributor_tier_color in .github/label-policy.json') + if not rules: + raise SystemExit('contributor_tiers must not be empty in .github/label-policy.json') + + labels = set() + prev_min = None + for entry in rules: + label = str(entry.get('label', '')).strip().lower() + min_merged = int(entry.get('min_merged_prs', 0)) + if not label.endswith('contributor'): + raise SystemExit(f'invalid contributor tier label: {label}') + if label in labels: + raise SystemExit(f'duplicate contributor tier label: {label}') + if prev_min is not None and min_merged > prev_min: + raise SystemExit('contributor_tiers must be sorted descending by min_merged_prs') + labels.add(label) + prev_min = min_merged + + workflow_paths = [ + Path('.github/workflows/labeler.yml'), + Path('.github/workflows/auto-response.yml'), + ] + for workflow in workflow_paths: + text = workflow.read_text(encoding='utf-8') + if '.github/label-policy.json' not in text: + raise SystemExit(f'{workflow} must load .github/label-policy.json') + if re.search(r'contributorTierColor\s*=\s*"[0-9A-Fa-f]{6}"', text): + raise SystemExit(f'{workflow} contains hardcoded contributorTierColor') + + print('label policy file is valid and workflow consumers are wired to shared policy') + PY diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 0000000..0e38f00 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,841 @@ +name: PR Labeler + +on: + pull_request_target: + types: [opened, reopened, synchronize, edited, labeled, unlabeled] + workflow_dispatch: + inputs: + mode: + description: "Run mode for managed-label governance" + required: true + default: "audit" + type: choice + options: + - audit + - repair + +concurrency: + group: pr-labeler-${{ github.event.pull_request.number || github.run_id }} + cancel-in-progress: true + +permissions: + contents: read + pull-requests: write + issues: write + +jobs: + label: + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 10 + steps: + - name: Apply path labels + if: github.event_name == 'pull_request_target' + uses: actions/labeler@8558fd74291d67161a8a78ce36a881fa63b766a9 # v5 + continue-on-error: true + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + sync-labels: true + + - name: Apply size/risk/module labels + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + continue-on-error: true + with: + script: | + const pr = context.payload.pull_request; + const owner = context.repo.owner; + const repo = context.repo.repo; + const action = context.payload.action; + const changedLabel = context.payload.label?.name; + + const sizeLabels = ["size: XS", "size: S", "size: M", "size: L", "size: XL"]; + const computedRiskLabels = ["risk: low", "risk: medium", "risk: high"]; + const manualRiskOverrideLabel = "risk: manual"; + const managedEnforcedLabels = new Set([ + ...sizeLabels, + manualRiskOverrideLabel, + ...computedRiskLabels, + ]); + if ((action === "labeled" || action === "unlabeled") && !managedEnforcedLabels.has(changedLabel)) { + core.info(`skip non-size/risk label event: ${changedLabel || "unknown"}`); + return; + } + + async function loadContributorTierPolicy() { + const fallback = { + contributorTierColor: "2ED9FF", + contributorTierRules: [ + { label: "distinguished contributor", minMergedPRs: 50 }, + { label: "principal contributor", minMergedPRs: 20 }, + { label: "experienced contributor", minMergedPRs: 10 }, + { label: "trusted contributor", minMergedPRs: 5 }, + ], + }; + try { + const { data } = await github.rest.repos.getContent({ + owner, + repo, + path: ".github/label-policy.json", + ref: context.payload.repository?.default_branch || "main", + }); + const json = JSON.parse(Buffer.from(data.content, "base64").toString("utf8")); + const contributorTierRules = (json.contributor_tiers || []).map((entry) => ({ + label: String(entry.label || "").trim(), + minMergedPRs: Number(entry.min_merged_prs || 0), + })); + const contributorTierColor = String(json.contributor_tier_color || "").toUpperCase(); + if (!contributorTierColor || contributorTierRules.length === 0) { + return fallback; + } + return { contributorTierColor, contributorTierRules }; + } catch (error) { + core.warning(`failed to load .github/label-policy.json, using fallback policy: ${error.message}`); + return fallback; + } + } + + const { contributorTierColor, contributorTierRules } = await loadContributorTierPolicy(); + const contributorTierLabels = contributorTierRules.map((rule) => rule.label); + + const managedPathLabels = [ + "docs", + "dependencies", + "ci", + "core", + "agent", + "channel", + "config", + "cron", + "daemon", + "doctor", + "gateway", + "health", + "heartbeat", + "integration", + "memory", + "observability", + "onboard", + "provider", + "runtime", + "security", + "service", + "skillforge", + "skills", + "tool", + "tunnel", + "tests", + "scripts", + "dev", + ]; + const managedPathLabelSet = new Set(managedPathLabels); + + const moduleNamespaceRules = [ + { root: "src/agent/", prefix: "agent", coreEntries: new Set(["mod.rs"]) }, + { root: "src/channels/", prefix: "channel", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/config/", prefix: "config", coreEntries: new Set(["mod.rs", "schema.rs"]) }, + { root: "src/cron/", prefix: "cron", coreEntries: new Set(["mod.rs"]) }, + { root: "src/daemon/", prefix: "daemon", coreEntries: new Set(["mod.rs"]) }, + { root: "src/doctor/", prefix: "doctor", coreEntries: new Set(["mod.rs"]) }, + { root: "src/gateway/", prefix: "gateway", coreEntries: new Set(["mod.rs"]) }, + { root: "src/health/", prefix: "health", coreEntries: new Set(["mod.rs"]) }, + { root: "src/heartbeat/", prefix: "heartbeat", coreEntries: new Set(["mod.rs"]) }, + { root: "src/integrations/", prefix: "integration", coreEntries: new Set(["mod.rs", "registry.rs"]) }, + { root: "src/memory/", prefix: "memory", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/observability/", prefix: "observability", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/onboard/", prefix: "onboard", coreEntries: new Set(["mod.rs"]) }, + { root: "src/providers/", prefix: "provider", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/runtime/", prefix: "runtime", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/security/", prefix: "security", coreEntries: new Set(["mod.rs"]) }, + { root: "src/service/", prefix: "service", coreEntries: new Set(["mod.rs"]) }, + { root: "src/skillforge/", prefix: "skillforge", coreEntries: new Set(["mod.rs"]) }, + { root: "src/skills/", prefix: "skills", coreEntries: new Set(["mod.rs"]) }, + { root: "src/tools/", prefix: "tool", coreEntries: new Set(["mod.rs", "traits.rs"]) }, + { root: "src/tunnel/", prefix: "tunnel", coreEntries: new Set(["mod.rs"]) }, + ]; + const managedModulePrefixes = [...new Set(moduleNamespaceRules.map((rule) => `${rule.prefix}:`))]; + const orderedOtherLabelStyles = [ + { label: "health", color: "8EC9B8" }, + { label: "tool", color: "7FC4B6" }, + { label: "agent", color: "86C4A2" }, + { label: "memory", color: "8FCB99" }, + { label: "channel", color: "7EB6F2" }, + { label: "service", color: "95C7B6" }, + { label: "integration", color: "8DC9AE" }, + { label: "tunnel", color: "9FC8B3" }, + { label: "config", color: "AABCD0" }, + { label: "observability", color: "84C9D0" }, + { label: "docs", color: "8FBBE0" }, + { label: "dev", color: "B9C1CC" }, + { label: "tests", color: "9DC8C7" }, + { label: "skills", color: "BFC89B" }, + { label: "skillforge", color: "C9C39B" }, + { label: "provider", color: "958DF0" }, + { label: "runtime", color: "A3ADD8" }, + { label: "heartbeat", color: "C0C88D" }, + { label: "daemon", color: "C8C498" }, + { label: "doctor", color: "C1CF9D" }, + { label: "onboard", color: "D2BF86" }, + { label: "cron", color: "D2B490" }, + { label: "ci", color: "AEB4CE" }, + { label: "dependencies", color: "9FB1DE" }, + { label: "gateway", color: "B5A8E5" }, + { label: "security", color: "E58D85" }, + { label: "core", color: "C8A99B" }, + { label: "scripts", color: "C9B49F" }, + ]; + const otherLabelDisplayOrder = orderedOtherLabelStyles.map((entry) => entry.label); + const modulePrefixSet = new Set(moduleNamespaceRules.map((rule) => rule.prefix)); + const modulePrefixPriority = otherLabelDisplayOrder.filter((label) => modulePrefixSet.has(label)); + const pathLabelPriority = [...otherLabelDisplayOrder]; + const riskDisplayOrder = ["risk: high", "risk: medium", "risk: low", "risk: manual"]; + const sizeDisplayOrder = ["size: XS", "size: S", "size: M", "size: L", "size: XL"]; + const contributorDisplayOrder = [ + "distinguished contributor", + "principal contributor", + "experienced contributor", + "trusted contributor", + ]; + const modulePrefixPriorityIndex = new Map( + modulePrefixPriority.map((prefix, index) => [prefix, index]) + ); + const pathLabelPriorityIndex = new Map( + pathLabelPriority.map((label, index) => [label, index]) + ); + const riskPriorityIndex = new Map( + riskDisplayOrder.map((label, index) => [label, index]) + ); + const sizePriorityIndex = new Map( + sizeDisplayOrder.map((label, index) => [label, index]) + ); + const contributorPriorityIndex = new Map( + contributorDisplayOrder.map((label, index) => [label, index]) + ); + + const otherLabelColors = Object.fromEntries( + orderedOtherLabelStyles.map((entry) => [entry.label, entry.color]) + ); + const staticLabelColors = { + "size: XS": "E7CDD3", + "size: S": "E1BEC7", + "size: M": "DBB0BB", + "size: L": "D4A2AF", + "size: XL": "CE94A4", + "risk: low": "97D3A6", + "risk: medium": "E4C47B", + "risk: high": "E98E88", + "risk: manual": "B7A4E0", + ...otherLabelColors, + }; + const staticLabelDescriptions = { + "size: XS": "Auto size: <=80 non-doc changed lines.", + "size: S": "Auto size: 81-250 non-doc changed lines.", + "size: M": "Auto size: 251-500 non-doc changed lines.", + "size: L": "Auto size: 501-1000 non-doc changed lines.", + "size: XL": "Auto size: >1000 non-doc changed lines.", + "risk: low": "Auto risk: docs/chore-only paths.", + "risk: medium": "Auto risk: src/** or dependency/config changes.", + "risk: high": "Auto risk: security/runtime/gateway/tools/workflows.", + "risk: manual": "Maintainer override: keep selected risk label.", + docs: "Auto scope: docs/markdown/template files changed.", + dependencies: "Auto scope: dependency manifest/lock/policy changed.", + ci: "Auto scope: CI/workflow/hook files changed.", + core: "Auto scope: root src/*.rs files changed.", + agent: "Auto scope: src/agent/** changed.", + channel: "Auto scope: src/channels/** changed.", + config: "Auto scope: src/config/** changed.", + cron: "Auto scope: src/cron/** changed.", + daemon: "Auto scope: src/daemon/** changed.", + doctor: "Auto scope: src/doctor/** changed.", + gateway: "Auto scope: src/gateway/** changed.", + health: "Auto scope: src/health/** changed.", + heartbeat: "Auto scope: src/heartbeat/** changed.", + integration: "Auto scope: src/integrations/** changed.", + memory: "Auto scope: src/memory/** changed.", + observability: "Auto scope: src/observability/** changed.", + onboard: "Auto scope: src/onboard/** changed.", + provider: "Auto scope: src/providers/** changed.", + runtime: "Auto scope: src/runtime/** changed.", + security: "Auto scope: src/security/** changed.", + service: "Auto scope: src/service/** changed.", + skillforge: "Auto scope: src/skillforge/** changed.", + skills: "Auto scope: src/skills/** changed.", + tool: "Auto scope: src/tools/** changed.", + tunnel: "Auto scope: src/tunnel/** changed.", + tests: "Auto scope: tests/** changed.", + scripts: "Auto scope: scripts/** changed.", + dev: "Auto scope: dev/** changed.", + }; + for (const label of contributorTierLabels) { + staticLabelColors[label] = contributorTierColor; + const rule = contributorTierRules.find((entry) => entry.label === label); + if (rule) { + staticLabelDescriptions[label] = `Contributor with ${rule.minMergedPRs}+ merged PRs.`; + } + } + + const modulePrefixColors = Object.fromEntries( + modulePrefixPriority.map((prefix) => [ + `${prefix}:`, + otherLabelColors[prefix] || "BFDADC", + ]) + ); + + const providerKeywordHints = [ + "deepseek", + "moonshot", + "kimi", + "qwen", + "mistral", + "doubao", + "baichuan", + "yi", + "siliconflow", + "vertex", + "azure", + "perplexity", + "venice", + "vercel", + "cloudflare", + "synthetic", + "opencode", + "zai", + "glm", + "minimax", + "bedrock", + "qianfan", + "groq", + "together", + "fireworks", + "cohere", + "openai", + "openrouter", + "anthropic", + "gemini", + "ollama", + ]; + + const channelKeywordHints = [ + "telegram", + "discord", + "slack", + "whatsapp", + "matrix", + "irc", + "imessage", + "email", + "cli", + ]; + + function isDocsLike(path) { + return ( + path.startsWith("docs/") || + path.endsWith(".md") || + path.endsWith(".mdx") || + path === "LICENSE" || + path === ".markdownlint-cli2.yaml" || + path === ".github/pull_request_template.md" || + path.startsWith(".github/ISSUE_TEMPLATE/") + ); + } + + function normalizeLabelSegment(segment) { + return (segment || "") + .toLowerCase() + .replace(/\.rs$/g, "") + .replace(/[^a-z0-9_-]+/g, "-") + .replace(/^[-_]+|[-_]+$/g, "") + .slice(0, 40); + } + + function containsKeyword(text, keyword) { + const escaped = keyword.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); + const pattern = new RegExp(`(^|[^a-z0-9_])${escaped}([^a-z0-9_]|$)`, "i"); + return pattern.test(text); + } + + function formatModuleLabel(prefix, segment) { + return `${prefix}: ${segment}`; + } + + function parseModuleLabel(label) { + if (typeof label !== "string") return null; + const match = label.match(/^([^:]+):\s*(.+)$/); + if (!match) return null; + const prefix = match[1].trim().toLowerCase(); + const segment = (match[2] || "").trim().toLowerCase(); + if (!prefix || !segment) return null; + return { prefix, segment }; + } + + function sortByPriority(labels, priorityIndex) { + return [...new Set(labels)].sort((left, right) => { + const leftPriority = priorityIndex.has(left) ? priorityIndex.get(left) : Number.MAX_SAFE_INTEGER; + const rightPriority = priorityIndex.has(right) + ? priorityIndex.get(right) + : Number.MAX_SAFE_INTEGER; + if (leftPriority !== rightPriority) return leftPriority - rightPriority; + return left.localeCompare(right); + }); + } + + function sortModuleLabels(labels) { + return [...new Set(labels)].sort((left, right) => { + const leftParsed = parseModuleLabel(left); + const rightParsed = parseModuleLabel(right); + if (!leftParsed || !rightParsed) return left.localeCompare(right); + + const leftPrefixPriority = modulePrefixPriorityIndex.has(leftParsed.prefix) + ? modulePrefixPriorityIndex.get(leftParsed.prefix) + : Number.MAX_SAFE_INTEGER; + const rightPrefixPriority = modulePrefixPriorityIndex.has(rightParsed.prefix) + ? modulePrefixPriorityIndex.get(rightParsed.prefix) + : Number.MAX_SAFE_INTEGER; + + if (leftPrefixPriority !== rightPrefixPriority) { + return leftPrefixPriority - rightPrefixPriority; + } + if (leftParsed.prefix !== rightParsed.prefix) { + return leftParsed.prefix.localeCompare(rightParsed.prefix); + } + + const leftIsCore = leftParsed.segment === "core"; + const rightIsCore = rightParsed.segment === "core"; + if (leftIsCore !== rightIsCore) return leftIsCore ? 1 : -1; + + return leftParsed.segment.localeCompare(rightParsed.segment); + }); + } + + function refineModuleLabels(rawLabels) { + const refined = new Set(rawLabels); + const segmentsByPrefix = new Map(); + + for (const label of rawLabels) { + const parsed = parseModuleLabel(label); + if (!parsed) continue; + if (!segmentsByPrefix.has(parsed.prefix)) { + segmentsByPrefix.set(parsed.prefix, new Set()); + } + segmentsByPrefix.get(parsed.prefix).add(parsed.segment); + } + + for (const [prefix, segments] of segmentsByPrefix) { + const hasSpecificSegment = [...segments].some((segment) => segment !== "core"); + if (hasSpecificSegment) { + refined.delete(formatModuleLabel(prefix, "core")); + } + } + + return refined; + } + + function compactModuleLabels(labels) { + const groupedSegments = new Map(); + const compactedModuleLabels = new Set(); + const forcePathPrefixes = new Set(); + + for (const label of labels) { + const parsed = parseModuleLabel(label); + if (!parsed) { + compactedModuleLabels.add(label); + continue; + } + if (!groupedSegments.has(parsed.prefix)) { + groupedSegments.set(parsed.prefix, new Set()); + } + groupedSegments.get(parsed.prefix).add(parsed.segment); + } + + for (const [prefix, segments] of groupedSegments) { + const uniqueSegments = [...new Set([...segments].filter(Boolean))]; + if (uniqueSegments.length === 0) continue; + + if (uniqueSegments.length === 1) { + compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0])); + } else { + forcePathPrefixes.add(prefix); + } + } + + return { + moduleLabels: compactedModuleLabels, + forcePathPrefixes, + }; + } + + function colorForLabel(label) { + if (staticLabelColors[label]) return staticLabelColors[label]; + const matchedPrefix = Object.keys(modulePrefixColors).find((prefix) => label.startsWith(prefix)); + if (matchedPrefix) return modulePrefixColors[matchedPrefix]; + return "BFDADC"; + } + + function descriptionForLabel(label) { + if (staticLabelDescriptions[label]) return staticLabelDescriptions[label]; + + const parsed = parseModuleLabel(label); + if (parsed) { + if (parsed.segment === "core") { + return `Auto module: ${parsed.prefix} core files changed.`; + } + return `Auto module: ${parsed.prefix}/${parsed.segment} changed.`; + } + + return "Auto-managed label."; + } + + async function ensureLabel(name, existing = null) { + const expectedColor = colorForLabel(name); + const expectedDescription = descriptionForLabel(name); + try { + const current = existing || (await github.rest.issues.getLabel({ owner, repo, name })).data; + const currentColor = (current.color || "").toUpperCase(); + const currentDescription = (current.description || "").trim(); + if (currentColor !== expectedColor || currentDescription !== expectedDescription) { + await github.rest.issues.updateLabel({ + owner, + repo, + name, + new_name: name, + color: expectedColor, + description: expectedDescription, + }); + } + } catch (error) { + if (error.status !== 404) throw error; + await github.rest.issues.createLabel({ + owner, + repo, + name, + color: expectedColor, + description: expectedDescription, + }); + } + } + + function isManagedLabel(label) { + if (label === manualRiskOverrideLabel) return true; + if (sizeLabels.includes(label) || computedRiskLabels.includes(label)) return true; + if (managedPathLabelSet.has(label)) return true; + if (contributorTierLabels.includes(label)) return true; + if (managedModulePrefixes.some((prefix) => label.startsWith(prefix))) return true; + return false; + } + + async function ensureManagedRepoLabelsMetadata() { + const repoLabels = await github.paginate(github.rest.issues.listLabelsForRepo, { + owner, + repo, + per_page: 100, + }); + + for (const existingLabel of repoLabels) { + const labelName = existingLabel.name || ""; + if (!isManagedLabel(labelName)) continue; + await ensureLabel(labelName, existingLabel); + } + } + + function selectContributorTier(mergedCount) { + const matchedTier = contributorTierRules.find((rule) => mergedCount >= rule.minMergedPRs); + return matchedTier ? matchedTier.label : null; + } + + if (context.eventName === "workflow_dispatch") { + const mode = (context.payload.inputs?.mode || "audit").toLowerCase(); + const shouldRepair = mode === "repair"; + const repoLabels = await github.paginate(github.rest.issues.listLabelsForRepo, { + owner, + repo, + per_page: 100, + }); + + let managedScanned = 0; + const drifts = []; + + for (const existingLabel of repoLabels) { + const labelName = existingLabel.name || ""; + if (!isManagedLabel(labelName)) continue; + managedScanned += 1; + + const expectedColor = colorForLabel(labelName); + const expectedDescription = descriptionForLabel(labelName); + const currentColor = (existingLabel.color || "").toUpperCase(); + const currentDescription = (existingLabel.description || "").trim(); + if (currentColor !== expectedColor || currentDescription !== expectedDescription) { + drifts.push({ + name: labelName, + currentColor, + expectedColor, + currentDescription, + expectedDescription, + }); + if (shouldRepair) { + await ensureLabel(labelName, existingLabel); + } + } + } + + core.summary + .addHeading("Managed Label Governance", 2) + .addRaw(`Mode: ${shouldRepair ? "repair" : "audit"}`) + .addEOL() + .addRaw(`Managed labels scanned: ${managedScanned}`) + .addEOL() + .addRaw(`Drifts found: ${drifts.length}`) + .addEOL(); + + if (drifts.length > 0) { + const sample = drifts.slice(0, 30).map((entry) => [ + entry.name, + `${entry.currentColor} -> ${entry.expectedColor}`, + `${entry.currentDescription || "(blank)"} -> ${entry.expectedDescription}`, + ]); + core.summary.addTable([ + [{ data: "Label", header: true }, { data: "Color", header: true }, { data: "Description", header: true }], + ...sample, + ]); + if (drifts.length > sample.length) { + core.summary + .addRaw(`Additional drifts not shown: ${drifts.length - sample.length}`) + .addEOL(); + } + } + + await core.summary.write(); + + if (!shouldRepair && drifts.length > 0) { + core.info(`Managed-label metadata drifts detected: ${drifts.length}. Re-run with mode=repair to auto-fix.`); + } else if (shouldRepair) { + core.info(`Managed-label metadata repair applied to ${drifts.length} labels.`); + } else { + core.info("No managed-label metadata drift detected."); + } + + return; + } + + const files = await github.paginate(github.rest.pulls.listFiles, { + owner, + repo, + pull_number: pr.number, + per_page: 100, + }); + + const detectedModuleLabels = new Set(); + for (const file of files) { + const path = (file.filename || "").toLowerCase(); + for (const rule of moduleNamespaceRules) { + if (!path.startsWith(rule.root)) continue; + + const relative = path.slice(rule.root.length); + if (!relative) continue; + + const first = relative.split("/")[0]; + const firstStem = first.endsWith(".rs") ? first.slice(0, -3) : first; + let segment = firstStem; + + if (rule.coreEntries.has(first) || rule.coreEntries.has(firstStem)) { + segment = "core"; + } + + segment = normalizeLabelSegment(segment); + if (!segment) continue; + + detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment)); + } + } + + const providerRelevantFiles = files.filter((file) => { + const path = file.filename || ""; + return ( + path.startsWith("src/providers/") || + path.startsWith("src/integrations/") || + path.startsWith("src/onboard/") || + path.startsWith("src/config/") + ); + }); + + if (providerRelevantFiles.length > 0) { + const searchableText = [ + pr.title || "", + pr.body || "", + ...providerRelevantFiles.map((file) => file.filename || ""), + ...providerRelevantFiles.map((file) => file.patch || ""), + ] + .join("\n") + .toLowerCase(); + + for (const keyword of providerKeywordHints) { + if (containsKeyword(searchableText, keyword)) { + detectedModuleLabels.add(formatModuleLabel("provider", keyword)); + } + } + } + + const channelRelevantFiles = files.filter((file) => { + const path = file.filename || ""; + return ( + path.startsWith("src/channels/") || + path.startsWith("src/onboard/") || + path.startsWith("src/config/") + ); + }); + + if (channelRelevantFiles.length > 0) { + const searchableText = [ + pr.title || "", + pr.body || "", + ...channelRelevantFiles.map((file) => file.filename || ""), + ...channelRelevantFiles.map((file) => file.patch || ""), + ] + .join("\n") + .toLowerCase(); + + for (const keyword of channelKeywordHints) { + if (containsKeyword(searchableText, keyword)) { + detectedModuleLabels.add(formatModuleLabel("channel", keyword)); + } + } + } + + const refinedModuleLabels = refineModuleLabels(detectedModuleLabels); + const compactedModuleState = compactModuleLabels(refinedModuleLabels); + const selectedModuleLabels = compactedModuleState.moduleLabels; + const forcePathPrefixes = compactedModuleState.forcePathPrefixes; + const modulePrefixesWithLabels = new Set( + [...selectedModuleLabels] + .map((label) => parseModuleLabel(label)?.prefix) + .filter(Boolean) + ); + + const { data: currentLabels } = await github.rest.issues.listLabelsOnIssue({ + owner, + repo, + issue_number: pr.number, + }); + const currentLabelNames = currentLabels.map((label) => label.name); + const currentPathLabels = currentLabelNames.filter((label) => managedPathLabelSet.has(label)); + const candidatePathLabels = new Set([...currentPathLabels, ...forcePathPrefixes]); + + const dedupedPathLabels = [...candidatePathLabels].filter((label) => { + if (label === "core") return true; + if (forcePathPrefixes.has(label)) return true; + return !modulePrefixesWithLabels.has(label); + }); + + const excludedLockfiles = new Set(["Cargo.lock"]); + const changedLines = files.reduce((total, file) => { + const path = file.filename || ""; + if (isDocsLike(path) || excludedLockfiles.has(path)) { + return total; + } + return total + (file.additions || 0) + (file.deletions || 0); + }, 0); + + let sizeLabel = "size: XL"; + if (changedLines <= 80) sizeLabel = "size: XS"; + else if (changedLines <= 250) sizeLabel = "size: S"; + else if (changedLines <= 500) sizeLabel = "size: M"; + else if (changedLines <= 1000) sizeLabel = "size: L"; + + const hasHighRiskPath = files.some((file) => { + const path = file.filename || ""; + return ( + path.startsWith("src/security/") || + path.startsWith("src/runtime/") || + path.startsWith("src/gateway/") || + path.startsWith("src/tools/") || + path.startsWith(".github/workflows/") + ); + }); + + const hasMediumRiskPath = files.some((file) => { + const path = file.filename || ""; + return ( + path.startsWith("src/") || + path === "Cargo.toml" || + path === "Cargo.lock" || + path === "deny.toml" || + path.startsWith(".githooks/") + ); + }); + + let riskLabel = "risk: low"; + if (hasHighRiskPath) { + riskLabel = "risk: high"; + } else if (hasMediumRiskPath) { + riskLabel = "risk: medium"; + } + + await ensureManagedRepoLabelsMetadata(); + + const labelsToEnsure = new Set([ + ...sizeLabels, + ...computedRiskLabels, + manualRiskOverrideLabel, + ...managedPathLabels, + ...contributorTierLabels, + ...selectedModuleLabels, + ]); + + for (const label of labelsToEnsure) { + await ensureLabel(label); + } + + let contributorTierLabel = null; + const authorLogin = pr.user?.login; + if (authorLogin && pr.user?.type !== "Bot") { + try { + const { data: mergedSearch } = await github.rest.search.issuesAndPullRequests({ + q: `repo:${owner}/${repo} is:pr is:merged author:${authorLogin}`, + per_page: 1, + }); + const mergedCount = mergedSearch.total_count || 0; + contributorTierLabel = selectContributorTier(mergedCount); + } catch (error) { + core.warning(`failed to compute contributor tier label: ${error.message}`); + } + } + + const hasManualRiskOverride = currentLabelNames.includes(manualRiskOverrideLabel); + const keepNonManagedLabels = currentLabelNames.filter((label) => { + if (label === manualRiskOverrideLabel) return true; + if (contributorTierLabels.includes(label)) return false; + if (sizeLabels.includes(label) || computedRiskLabels.includes(label)) return false; + if (managedPathLabelSet.has(label)) return false; + if (managedModulePrefixes.some((prefix) => label.startsWith(prefix))) return false; + return true; + }); + + const manualRiskSelection = + currentLabelNames.find((label) => computedRiskLabels.includes(label)) || riskLabel; + + const moduleLabelList = sortModuleLabels([...selectedModuleLabels]); + const contributorLabelList = contributorTierLabel ? [contributorTierLabel] : []; + const selectedRiskLabels = hasManualRiskOverride + ? sortByPriority([manualRiskSelection, manualRiskOverrideLabel], riskPriorityIndex) + : sortByPriority([riskLabel], riskPriorityIndex); + const selectedSizeLabels = sortByPriority([sizeLabel], sizePriorityIndex); + const sortedContributorLabels = sortByPriority(contributorLabelList, contributorPriorityIndex); + const sortedPathLabels = sortByPriority(dedupedPathLabels, pathLabelPriorityIndex); + const sortedKeepNonManagedLabels = [...new Set(keepNonManagedLabels)].sort((left, right) => + left.localeCompare(right) + ); + + const nextLabels = [ + ...new Set([ + ...selectedRiskLabels, + ...selectedSizeLabels, + ...sortedContributorLabels, + ...moduleLabelList, + ...sortedPathLabels, + ...sortedKeepNonManagedLabels, + ]), + ]; + + await github.rest.issues.setLabels({ + owner, + repo, + issue_number: pr.number, + labels: nextLabels, + }); diff --git a/.github/workflows/pr-hygiene.yml b/.github/workflows/pr-hygiene.yml new file mode 100644 index 0000000..28f536c --- /dev/null +++ b/.github/workflows/pr-hygiene.yml @@ -0,0 +1,184 @@ +name: PR Hygiene + +on: + schedule: + - cron: "15 */12 * * *" + workflow_dispatch: + +permissions: {} + +concurrency: + group: pr-hygiene + cancel-in-progress: true + +jobs: + nudge-stale-prs: + runs-on: blacksmith-2vcpu-ubuntu-2404 + permissions: + contents: read + pull-requests: write + issues: write + env: + STALE_HOURS: "48" + steps: + - name: Nudge PRs that need rebase or CI refresh + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + with: + script: | + const staleHours = Number(process.env.STALE_HOURS || "48"); + const ignoreLabels = new Set(["no-stale", "stale", "maintainer", "no-pr-hygiene"]); + const marker = ""; + const owner = context.repo.owner; + const repo = context.repo.repo; + + const openPrs = await github.paginate(github.rest.pulls.list, { + owner, + repo, + state: "open", + per_page: 100, + }); + + const activePrs = openPrs.filter((pr) => { + if (pr.draft) { + return false; + } + + const labels = new Set((pr.labels || []).map((label) => label.name)); + return ![...ignoreLabels].some((label) => labels.has(label)); + }); + + core.info(`Scanning ${activePrs.length} open PR(s) for hygiene nudges.`); + + let nudged = 0; + let skipped = 0; + + for (const pr of activePrs) { + const { data: headCommit } = await github.rest.repos.getCommit({ + owner, + repo, + ref: pr.head.sha, + }); + + const headCommitAt = + headCommit.commit?.committer?.date || headCommit.commit?.author?.date; + if (!headCommitAt) { + skipped += 1; + core.info(`#${pr.number}: missing head commit timestamp, skipping.`); + continue; + } + + const ageHours = (Date.now() - new Date(headCommitAt).getTime()) / 3600000; + if (ageHours < staleHours) { + skipped += 1; + continue; + } + + const { data: prDetail } = await github.rest.pulls.get({ + owner, + repo, + pull_number: pr.number, + }); + + const isBehindBase = prDetail.mergeable_state === "behind"; + + const { data: checkRunsData } = await github.rest.checks.listForRef({ + owner, + repo, + ref: pr.head.sha, + per_page: 100, + }); + + const ciGateRuns = (checkRunsData.check_runs || []) + .filter((run) => run.name === "CI Required Gate") + .sort((a, b) => { + const aTime = new Date(a.started_at || a.completed_at || a.created_at).getTime(); + const bTime = new Date(b.started_at || b.completed_at || b.created_at).getTime(); + return bTime - aTime; + }); + + let ciState = "missing"; + if (ciGateRuns.length > 0) { + const latest = ciGateRuns[0]; + if (latest.status !== "completed") { + ciState = "in_progress"; + } else if (["success", "neutral", "skipped"].includes(latest.conclusion || "")) { + ciState = "success"; + } else { + ciState = String(latest.conclusion || "failure"); + } + } + + const ciMissing = ciState === "missing"; + const ciFailing = !["success", "in_progress", "missing"].includes(ciState); + + if (!isBehindBase && !ciMissing && !ciFailing) { + skipped += 1; + continue; + } + + const reasons = []; + if (isBehindBase) { + reasons.push("- Branch is behind `main` (please rebase or merge the latest base branch)."); + } + if (ciMissing) { + reasons.push("- No `CI Required Gate` run was found for the current head commit."); + } + if (ciFailing) { + reasons.push(`- Latest \`CI Required Gate\` result is \`${ciState}\`.`); + } + + const shortSha = pr.head.sha.slice(0, 12); + const body = [ + marker, + `Hi @${pr.user.login}, friendly automation nudge from PR hygiene.`, + "", + `This PR has had no new commits for **${Math.floor(ageHours)}h** and still needs an update before merge:`, + "", + ...reasons, + "", + "### Recommended next steps", + "1. Rebase your branch on `main`.", + "2. Push the updated branch and re-run checks (or use **Re-run failed jobs**).", + "3. Post fresh validation output in this PR thread.", + "", + "Maintainers: apply `no-stale` to opt out for accepted-but-blocked work.", + `Head SHA: \`${shortSha}\``, + ].join("\n"); + + const { data: comments } = await github.rest.issues.listComments({ + owner, + repo, + issue_number: pr.number, + per_page: 100, + }); + + const existing = comments.find( + (comment) => comment.user?.type === "Bot" && comment.body?.includes(marker), + ); + + if (existing) { + if (existing.body === body) { + skipped += 1; + continue; + } + + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existing.id, + body, + }); + } else { + await github.rest.issues.createComment({ + owner, + repo, + issue_number: pr.number, + body, + }); + } + + nudged += 1; + core.info(`#${pr.number}: hygiene nudge posted/updated.`); + } + + core.info(`Done. Nudged=${nudged}, skipped=${skipped}`); diff --git a/.github/workflows/pr-intake-sanity.yml b/.github/workflows/pr-intake-sanity.yml new file mode 100644 index 0000000..10a597e --- /dev/null +++ b/.github/workflows/pr-intake-sanity.yml @@ -0,0 +1,179 @@ +name: PR Intake Sanity + +on: + pull_request_target: + types: [opened, reopened, synchronize, edited, ready_for_review] + +concurrency: + group: pr-intake-sanity-${{ github.event.pull_request.number || github.run_id }} + cancel-in-progress: true + +permissions: + contents: read + pull-requests: write + issues: write + +jobs: + intake: + name: Intake Sanity + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 10 + steps: + - name: Run safe PR intake checks + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + with: + script: | + const owner = context.repo.owner; + const repo = context.repo.repo; + const pr = context.payload.pull_request; + if (!pr) return; + + const marker = ""; + const requiredSections = [ + "## Summary", + "## Validation Evidence (required)", + "## Security Impact (required)", + "## Privacy and Data Hygiene (required)", + "## Rollback Plan (required)", + ]; + const body = pr.body || ""; + + const missingSections = requiredSections.filter((section) => !body.includes(section)); + const missingFields = []; + const requiredFieldChecks = [ + ["summary problem", /- Problem:\s*\S+/m], + ["summary why it matters", /- Why it matters:\s*\S+/m], + ["summary what changed", /- What changed:\s*\S+/m], + ["validation commands", /Commands and result summary:\s*[\s\S]*```/m], + ["security risk/mitigation", /- New permissions\/capabilities\?\s*\(`Yes\/No`\):\s*\S+/m], + ["privacy status", /- Data-hygiene status\s*\(`pass\|needs-follow-up`\):\s*\S+/m], + ["rollback plan", /- Fast rollback command\/path:\s*\S+/m], + ]; + for (const [name, pattern] of requiredFieldChecks) { + if (!pattern.test(body)) { + missingFields.push(name); + } + } + + const files = await github.paginate(github.rest.pulls.listFiles, { + owner, + repo, + pull_number: pr.number, + per_page: 100, + }); + + const formatProblems = []; + for (const file of files) { + const patch = file.patch || ""; + if (!patch) continue; + const lines = patch.split("\n"); + for (let idx = 0; idx < lines.length; idx += 1) { + const line = lines[idx]; + if (!line.startsWith("+") || line.startsWith("+++")) continue; + const added = line.slice(1); + const lineNo = idx + 1; + if (/\t/.test(added)) { + formatProblems.push(`${file.filename}:patch#${lineNo} contains tab characters`); + } + if (/[ \t]+$/.test(added)) { + formatProblems.push(`${file.filename}:patch#${lineNo} contains trailing whitespace`); + } + if (/^(<<<<<<<|=======|>>>>>>>)/.test(added)) { + formatProblems.push(`${file.filename}:patch#${lineNo} contains merge conflict markers`); + } + } + } + + const workflowFilesChanged = files + .map((file) => file.filename) + .filter((name) => name.startsWith(".github/workflows/")); + + const failures = []; + if (missingSections.length > 0) { + failures.push(`Missing required PR template sections: ${missingSections.join(", ")}`); + } + if (missingFields.length > 0) { + failures.push(`Incomplete required PR template fields: ${missingFields.join(", ")}`); + } + if (formatProblems.length > 0) { + failures.push(`Formatting/safety issues in added lines (${formatProblems.length})`); + } + + const comments = await github.paginate(github.rest.issues.listComments, { + owner, + repo, + issue_number: pr.number, + per_page: 100, + }); + const existing = comments.find((comment) => (comment.body || "").includes(marker)); + + if (failures.length === 0) { + if (existing) { + await github.rest.issues.deleteComment({ + owner, + repo, + comment_id: existing.id, + }); + } + core.info("PR intake sanity checks passed."); + return; + } + + const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`; + const details = []; + if (formatProblems.length > 0) { + details.push(...formatProblems.slice(0, 20).map((entry) => `- ${entry}`)); + if (formatProblems.length > 20) { + details.push(`- ...and ${formatProblems.length - 20} more issue(s)`); + } + } + + const ownerApprovalNote = workflowFilesChanged.length > 0 + ? [ + "", + "Workflow files changed in this PR:", + ...workflowFilesChanged.map((name) => `- \`${name}\``), + "", + "Reminder: workflow changes require owner approval via `CI Required Gate`.", + ].join("\n") + : ""; + + const commentBody = [ + marker, + "### PR intake checks failed", + "", + "Fast safe checks ran before full CI and found issues:", + ...failures.map((entry) => `- ${entry}`), + "", + "Action items:", + "1. Complete the required PR template sections/fields.", + "2. Remove tabs, trailing whitespace, and conflict markers from added lines.", + "3. Re-run local checks before pushing:", + " - `./scripts/ci/rust_quality_gate.sh`", + " - `./scripts/ci/rust_strict_delta_gate.sh`", + " - `./scripts/ci/docs_quality_gate.sh`", + "", + `Run logs: ${runUrl}`, + "", + "Detected line issues (sample):", + ...(details.length > 0 ? details : ["- none"]), + ownerApprovalNote, + ].join("\n"); + + if (existing) { + await github.rest.issues.updateComment({ + owner, + repo, + comment_id: existing.id, + body: commentBody, + }); + } else { + await github.rest.issues.createComment({ + owner, + repo, + issue_number: pr.number, + body: commentBody, + }); + } + + core.setFailed("PR intake sanity checks failed. See sticky comment for details."); diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4a2b071..e8c3cd3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,90 +1,117 @@ name: Release on: - push: - tags: ["v*"] + push: + tags: ["v*"] permissions: - contents: write + contents: write + id-token: write # Required for cosign keyless signing via OIDC env: - CARGO_TERM_COLOR: always + CARGO_TERM_COLOR: always jobs: - build-release: - name: Build ${{ matrix.target }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - artifact: zeroclaw - - os: macos-latest - target: x86_64-apple-darwin - artifact: zeroclaw - - os: macos-latest - target: aarch64-apple-darwin - artifact: zeroclaw - - os: windows-latest - target: x86_64-pc-windows-msvc - artifact: zeroclaw.exe + build-release: + name: Build ${{ matrix.target }} + runs-on: ${{ matrix.os }} + timeout-minutes: 40 + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-latest + target: blacksmith-2vcpu-ubuntu-2404 + artifact: zeroclaw + - os: macos-latest + target: x86_64-apple-darwin + artifact: zeroclaw + - os: macos-latest + target: aarch64-apple-darwin + artifact: zeroclaw + - os: windows-latest + target: x86_64-pc-windows-msvc + artifact: zeroclaw.exe - steps: - - uses: actions/checkout@v4 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - uses: dtolnay/rust-toolchain@stable - with: - targets: ${{ matrix.target }} + - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + targets: ${{ matrix.target }} - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 - - name: Build release - run: cargo build --release --target ${{ matrix.target }} + - name: Build release + run: cargo build --release --locked --target ${{ matrix.target }} - - name: Check binary size (Unix) - if: runner.os != 'Windows' - run: | - SIZE=$(stat -f%z target/${{ matrix.target }}/release/${{ matrix.artifact }} 2>/dev/null || stat -c%s target/${{ matrix.target }}/release/${{ matrix.artifact }}) - echo "Binary size: $((SIZE / 1024 / 1024))MB ($SIZE bytes)" - if [ "$SIZE" -gt 5242880 ]; then - echo "::warning::Binary exceeds 5MB target" - fi + - name: Check binary size (Unix) + if: runner.os != 'Windows' + run: | + SIZE=$(stat -f%z target/${{ matrix.target }}/release/${{ matrix.artifact }} 2>/dev/null || stat -c%s target/${{ matrix.target }}/release/${{ matrix.artifact }}) + echo "Binary size: $((SIZE / 1024 / 1024))MB ($SIZE bytes)" + if [ "$SIZE" -gt 5242880 ]; then + echo "::warning::Binary exceeds 5MB target" + fi - - name: Package (Unix) - if: runner.os != 'Windows' - run: | - cd target/${{ matrix.target }}/release - tar czf ../../../zeroclaw-${{ matrix.target }}.tar.gz ${{ matrix.artifact }} + - name: Package (Unix) + if: runner.os != 'Windows' + run: | + cd target/${{ matrix.target }}/release + tar czf ../../../zeroclaw-${{ matrix.target }}.tar.gz ${{ matrix.artifact }} - - name: Package (Windows) - if: runner.os == 'Windows' - run: | - cd target/${{ matrix.target }}/release - 7z a ../../../zeroclaw-${{ matrix.target }}.zip ${{ matrix.artifact }} + - name: Package (Windows) + if: runner.os == 'Windows' + run: | + cd target/${{ matrix.target }}/release + 7z a ../../../zeroclaw-${{ matrix.target }}.zip ${{ matrix.artifact }} - - name: Upload artifact - uses: actions/upload-artifact@v4 - with: - name: zeroclaw-${{ matrix.target }} - path: zeroclaw-${{ matrix.target }}.* + - name: Upload artifact + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: zeroclaw-${{ matrix.target }} + path: zeroclaw-${{ matrix.target }}.* - publish: - name: Publish Release - needs: build-release - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 + publish: + name: Publish Release + needs: build-release + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 15 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - - name: Download all artifacts - uses: actions/download-artifact@v4 - with: - path: artifacts + - name: Download all artifacts + uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4 + with: + path: artifacts - - name: Create GitHub Release - uses: softprops/action-gh-release@v2 - with: - generate_release_notes: true - files: artifacts/**/* - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Generate SHA256 checksums + run: | + cd artifacts + find . -type f \( -name '*.tar.gz' -o -name '*.zip' \) -exec sha256sum {} + | sed 's| \./[^/]*/| |' > SHA256SUMS + echo "Generated checksums:" + cat SHA256SUMS + + - name: Install cosign + uses: sigstore/cosign-installer@3454372f43399081ed03b604cb2d021dabca52bb # v3.8.2 + + - name: Sign artifacts with cosign (keyless) + run: | + for file in artifacts/**/*; do + [ -f "$file" ] || continue + cosign sign-blob --yes \ + --oidc-issuer=https://token.actions.githubusercontent.com \ + --output-signature="${file}.sig" \ + --output-certificate="${file}.pem" \ + "$file" + done + + - name: Create GitHub Release + uses: softprops/action-gh-release@a06a81a03ee405af7f2048a818ed3f03bbf83c7b # v2 + with: + generate_release_notes: true + files: | + artifacts/**/* + artifacts/SHA256SUMS + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/rust-reusable.yml b/.github/workflows/rust-reusable.yml new file mode 100644 index 0000000..511ccc4 --- /dev/null +++ b/.github/workflows/rust-reusable.yml @@ -0,0 +1,62 @@ +name: Rust Reusable Job + +on: + workflow_call: + inputs: + run_command: + description: "Shell command(s) to execute." + required: true + type: string + timeout_minutes: + description: "Job timeout in minutes." + required: false + default: 20 + type: number + toolchain: + description: "Rust toolchain channel/version." + required: false + default: "stable" + type: string + components: + description: "Optional rustup components." + required: false + default: "" + type: string + targets: + description: "Optional rustup targets." + required: false + default: "" + type: string + use_cache: + description: "Whether to enable rust-cache." + required: false + default: true + type: boolean + +permissions: + contents: read + +jobs: + run: + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: ${{ inputs.timeout_minutes }} + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Setup Rust toolchain + uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable + with: + toolchain: ${{ inputs.toolchain }} + components: ${{ inputs.components }} + targets: ${{ inputs.targets }} + + - name: Restore Rust cache + if: inputs.use_cache + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2 + + - name: Run command + shell: bash + run: | + set -euo pipefail + ${{ inputs.run_command }} diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 822d96a..bf0b99a 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -1,37 +1,43 @@ -name: Security Audit +name: Rust Package Security Audit on: - push: - branches: [main] - pull_request: - branches: [main] - schedule: - - cron: "0 6 * * 1" # Weekly on Monday 6am UTC + push: + branches: [main] + pull_request: + branches: [main] + schedule: + - cron: "0 6 * * 1" # Weekly on Monday 6am UTC + +concurrency: + group: security-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +permissions: + contents: read + security-events: write + actions: read env: - CARGO_TERM_COLOR: always + CARGO_TERM_COLOR: always jobs: - audit: - name: Security Audit - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: dtolnay/rust-toolchain@stable - - - name: Install cargo-audit - run: cargo install cargo-audit - - - name: Run cargo-audit - run: cargo audit - - deny: - name: License & Supply Chain - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - uses: EmbarkStudios/cargo-deny-action@v2 + audit: + name: Security Audit + uses: ./.github/workflows/rust-reusable.yml with: - command: check advisories licenses sources + timeout_minutes: 20 + toolchain: stable + run_command: | + cargo install --locked cargo-audit --version 0.22.1 + cargo audit + + deny: + name: License & Supply Chain + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 20 + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - uses: EmbarkStudios/cargo-deny-action@3fd3802e88374d3fe9159b834c7714ec57d6c979 # v2 + with: + command: check advisories licenses sources diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000..f46af3f --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,44 @@ +name: Stale + +on: + schedule: + - cron: "20 2 * * *" + workflow_dispatch: + +permissions: {} + +jobs: + stale: + permissions: + issues: write + pull-requests: write + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - name: Mark stale issues and pull requests + uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + days-before-issue-stale: 21 + days-before-issue-close: 7 + days-before-pr-stale: 14 + days-before-pr-close: 7 + stale-issue-label: stale + stale-pr-label: stale + exempt-issue-labels: security,pinned,no-stale,no-pr-hygiene,maintainer + exempt-pr-labels: no-stale,no-pr-hygiene,maintainer + remove-stale-when-updated: true + exempt-all-assignees: true + operations-per-run: 300 + stale-issue-message: | + This issue was automatically marked as stale due to inactivity. + Please provide an update, reproduction details, or current status to keep it open. + close-issue-message: | + Closing this issue due to inactivity. + If the problem still exists on the latest `main`, please open a new issue with fresh repro steps. + close-issue-reason: not_planned + stale-pr-message: | + This PR was automatically marked as stale due to inactivity. + Please rebase/update and post the latest validation results. + close-pr-message: | + Closing this PR due to inactivity. + Maintainers can reopen once the branch is updated and validation is provided. diff --git a/.github/workflows/update-notice.yml b/.github/workflows/update-notice.yml new file mode 100644 index 0000000..8f8a80f --- /dev/null +++ b/.github/workflows/update-notice.yml @@ -0,0 +1,116 @@ +name: Update Contributors NOTICE + +on: + workflow_dispatch: + schedule: + # Run every Sunday at 00:00 UTC + - cron: '0 0 * * 0' + +concurrency: + group: update-notice-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write + pull-requests: write + +jobs: + update-notice: + name: Update NOTICE with new contributors + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Fetch contributors + id: contributors + env: + GH_TOKEN: ${{ github.token }} + run: | + # Fetch all contributors (excluding bots) + gh api \ + --paginate \ + "repos/${{ github.repository }}/contributors" \ + --jq '.[] | select(.type != "Bot") | .login' > /tmp/contributors_raw.txt + + # Sort alphabetically and filter + sort -f < /tmp/contributors_raw.txt > contributors.txt + + # Count contributors + count=$(wc -l < contributors.txt | tr -d ' ') + echo "count=$count" >> "$GITHUB_OUTPUT" + + - name: Generate new NOTICE file + run: | + cat > NOTICE << 'EOF' + ZeroClaw + Copyright 2025 ZeroClaw Labs + + This product includes software developed at ZeroClaw Labs (https://github.com/zeroclaw-labs). + + Contributors + ============ + + The following individuals have contributed to ZeroClaw: + + EOF + + # Append contributors in alphabetical order + sed 's/^/- /' contributors.txt >> NOTICE + + # Add third-party dependencies section + cat >> NOTICE << 'EOF' + + + Third-Party Dependencies + ========================= + + This project uses the following third-party libraries and components, + each licensed under their respective terms: + + See Cargo.lock for a complete list of dependencies and their licenses. + EOF + + - name: Check if NOTICE changed + id: check_diff + run: | + if git diff --quiet NOTICE; then + echo "changed=false" >> "$GITHUB_OUTPUT" + else + echo "changed=true" >> "$GITHUB_OUTPUT" + fi + + - name: Create Pull Request + if: steps.check_diff.outputs.changed == 'true' + env: + GH_TOKEN: ${{ github.token }} + COUNT: ${{ steps.contributors.outputs.count }} + run: | + branch_name="auto/update-notice-$(date +%Y%m%d)" + + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + git checkout -b "$branch_name" + git add NOTICE + git commit -m "chore(notice): update contributor list" + git push origin "$branch_name" + + gh pr create \ + --title "chore(notice): update contributor list" \ + --body "Auto-generated update to NOTICE file with $COUNT contributors." \ + --label "chore" \ + --label "docs" \ + --draft || true + + - name: Summary + run: | + echo "## NOTICE Update Results" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + if [ "${{ steps.check_diff.outputs.changed }}" = "true" ]; then + echo "✅ PR created to update NOTICE" >> "$GITHUB_STEP_SUMMARY" + else + echo "✓ NOTICE file is up to date" >> "$GITHUB_STEP_SUMMARY" + fi + echo "" >> "$GITHUB_STEP_SUMMARY" + echo "**Contributors:** ${{ steps.contributors.outputs.count }}" >> "$GITHUB_STEP_SUMMARY" diff --git a/.github/workflows/workflow-sanity.yml b/.github/workflows/workflow-sanity.yml new file mode 100644 index 0000000..f353144 --- /dev/null +++ b/.github/workflows/workflow-sanity.yml @@ -0,0 +1,64 @@ +name: Workflow Sanity + +on: + pull_request: + paths: + - ".github/workflows/**" + - ".github/*.yml" + - ".github/*.yaml" + push: + paths: + - ".github/workflows/**" + - ".github/*.yml" + - ".github/*.yaml" + +concurrency: + group: workflow-sanity-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +permissions: + contents: read + +jobs: + no-tabs: + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Fail on tabs in workflow files + shell: bash + run: | + set -euo pipefail + python - <<'PY' + from __future__ import annotations + + import pathlib + import sys + + root = pathlib.Path(".github/workflows") + bad: list[str] = [] + for path in sorted(root.rglob("*.yml")): + if b"\t" in path.read_bytes(): + bad.append(str(path)) + for path in sorted(root.rglob("*.yaml")): + if b"\t" in path.read_bytes(): + bad.append(str(path)) + + if bad: + print("Tabs found in workflow file(s):") + for path in bad: + print(f"- {path}") + sys.exit(1) + PY + + actionlint: + runs-on: blacksmith-2vcpu-ubuntu-2404 + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + + - name: Lint GitHub workflows + uses: rhysd/actionlint@393031adb9afb225ee52ae2ccd7a5af5525e03e8 # v1.7.11 diff --git a/.gitignore b/.gitignore index 1520314..9440b79 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,29 @@ /target +firmware/*/target *.db *.db-journal .DS_Store +.wt-pr37/ +__pycache__/ +*.pyc +docker-compose.override.yml + +# Environment files (may contain secrets) +.env + +# Python virtual environments + +.venv/ +venv/ + +# ESP32 build cache (esp-idf-sys managed) + +.embuild/ +.env.local +.env.*.local + +# Secret keys and credentials +.secret_key +*.key +*.pem +credentials.json diff --git a/.markdownlint-cli2.yaml b/.markdownlint-cli2.yaml new file mode 100644 index 0000000..d6de542 --- /dev/null +++ b/.markdownlint-cli2.yaml @@ -0,0 +1,15 @@ +config: + default: true + MD013: false + MD007: false + MD031: false + MD032: false + MD033: false + MD040: false + MD041: false + MD060: false + MD024: + allow_different_nesting: true + +ignores: + - "target/**" diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..9746fdf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,413 @@ +# AGENTS.md — ZeroClaw Agent Engineering Protocol + +This file defines the default working protocol for coding agents in this repository. +Scope: entire repository. + +## 1) Project Snapshot (Read First) + +ZeroClaw is a Rust-first autonomous agent runtime optimized for: + +- high performance +- high efficiency +- high stability +- high extensibility +- high sustainability +- high security + +Core architecture is trait-driven and modular. Most extension work should be done by implementing traits and registering in factory modules. + +Key extension points: + +- `src/providers/traits.rs` (`Provider`) +- `src/channels/traits.rs` (`Channel`) +- `src/tools/traits.rs` (`Tool`) +- `src/memory/traits.rs` (`Memory`) +- `src/observability/traits.rs` (`Observer`) +- `src/runtime/traits.rs` (`RuntimeAdapter`) +- `src/peripherals/traits.rs` (`Peripheral`) — hardware boards (STM32, RPi GPIO) + +## 2) Deep Architecture Observations (Why This Protocol Exists) + +These codebase realities should drive every design decision: + +1. **Trait + factory architecture is the stability backbone** + - Extension points are intentionally explicit and swappable. + - Most features should be added via trait implementation + factory registration, not cross-cutting rewrites. +2. **Security-critical surfaces are first-class and internet-adjacent** + - `src/gateway/`, `src/security/`, `src/tools/`, `src/runtime/` carry high blast radius. + - Defaults already lean secure-by-default (pairing, bind safety, limits, secret handling); keep it that way. +3. **Performance and binary size are product goals, not nice-to-have** + - `Cargo.toml` release profile and dependency choices optimize for size and determinism. + - Convenience dependencies and broad abstractions can silently regress these goals. +4. **Config and runtime contracts are user-facing API** + - `src/config/schema.rs` and CLI commands are effectively public interfaces. + - Backward compatibility and explicit migration matter. +5. **The project now runs in high-concurrency collaboration mode** + - CI + docs governance + label routing are part of the product delivery system. + - PR throughput is a design constraint; not just a maintainer inconvenience. + +## 3) Engineering Principles (Normative) + +These principles are mandatory by default. They are not slogans; they are implementation constraints. + +### 3.1 KISS (Keep It Simple, Stupid) + +**Why here:** Runtime + security behavior must stay auditable under pressure. + +Required: + +- Prefer straightforward control flow over clever meta-programming. +- Prefer explicit match branches and typed structs over hidden dynamic behavior. +- Keep error paths obvious and localized. + +### 3.2 YAGNI (You Aren't Gonna Need It) + +**Why here:** Premature features increase attack surface and maintenance burden. + +Required: + +- Do not add new config keys, trait methods, feature flags, or workflow branches without a concrete accepted use case. +- Do not introduce speculative “future-proof” abstractions without at least one current caller. +- Keep unsupported paths explicit (error out) rather than adding partial fake support. + +### 3.3 DRY + Rule of Three + +**Why here:** Naive DRY can create brittle shared abstractions across providers/channels/tools. + +Required: + +- Duplicate small, local logic when it preserves clarity. +- Extract shared utilities only after repeated, stable patterns (rule-of-three). +- When extracting, preserve module boundaries and avoid hidden coupling. + +### 3.4 SRP + ISP (Single Responsibility + Interface Segregation) + +**Why here:** Trait-driven architecture already encodes subsystem boundaries. + +Required: + +- Keep each module focused on one concern. +- Extend behavior by implementing existing narrow traits whenever possible. +- Avoid fat interfaces and “god modules” that mix policy + transport + storage. + +### 3.5 Fail Fast + Explicit Errors + +**Why here:** Silent fallback in agent runtimes can create unsafe or costly behavior. + +Required: + +- Prefer explicit `bail!`/errors for unsupported or unsafe states. +- Never silently broaden permissions/capabilities. +- Document fallback behavior when fallback is intentional and safe. + +### 3.6 Secure by Default + Least Privilege + +**Why here:** Gateway/tools/runtime can execute actions with real-world side effects. + +Required: + +- Deny-by-default for access and exposure boundaries. +- Never log secrets, raw tokens, or sensitive payloads. +- Keep network/filesystem/shell scope as narrow as possible unless explicitly justified. + +### 3.7 Determinism + Reproducibility + +**Why here:** Reliable CI and low-latency triage depend on deterministic behavior. + +Required: + +- Prefer reproducible commands and locked dependency behavior in CI-sensitive paths. +- Keep tests deterministic (no flaky timing/network dependence without guardrails). +- Ensure local validation commands map to CI expectations. + +### 3.8 Reversibility + Rollback-First Thinking + +**Why here:** Fast recovery is mandatory under high PR volume. + +Required: + +- Keep changes easy to revert (small scope, clear blast radius). +- For risky changes, define rollback path before merge. +- Avoid mixed mega-patches that block safe rollback. + +## 4) Repository Map (High-Level) + +- `src/main.rs` — CLI entrypoint and command routing +- `src/lib.rs` — module exports and shared command enums +- `src/config/` — schema + config loading/merging +- `src/agent/` — orchestration loop +- `src/gateway/` — webhook/gateway server +- `src/security/` — policy, pairing, secret store +- `src/memory/` — markdown/sqlite memory backends + embeddings/vector merge +- `src/providers/` — model providers and resilient wrapper +- `src/channels/` — Telegram/Discord/Slack/etc channels +- `src/tools/` — tool execution surface (shell, file, memory, browser) +- `src/peripherals/` — hardware peripherals (STM32, RPi GPIO); see `docs/hardware-peripherals-design.md` +- `src/runtime/` — runtime adapters (currently native) +- `docs/` — architecture + process docs +- `.github/` — CI, templates, automation workflows + +## 5) Risk Tiers by Path (Review Depth Contract) + +Use these tiers when deciding validation depth and review rigor. + +- **Low risk**: docs/chore/tests-only changes +- **Medium risk**: most `src/**` behavior changes without boundary/security impact +- **High risk**: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**`, access-control boundaries + +When uncertain, classify as higher risk. + +## 6) Agent Workflow (Required) + +1. **Read before write** + - Inspect existing module, factory wiring, and adjacent tests before editing. +2. **Define scope boundary** + - One concern per PR; avoid mixed feature+refactor+infra patches. +3. **Implement minimal patch** + - Apply KISS/YAGNI/DRY rule-of-three explicitly. +4. **Validate by risk tier** + - Docs-only: lightweight checks. + - Code/risky changes: full relevant checks and focused scenarios. +5. **Document impact** + - Update docs/PR notes for behavior, risk, side effects, and rollback. +6. **Respect queue hygiene** + - If stacked PR: declare `Depends on #...`. + - If replacing old PR: declare `Supersedes #...`. + +### 6.3 Branch / Commit / PR Flow (Required) + +All contributors (human or agent) must follow the same collaboration flow: + +- Create and work from a non-`main` branch. +- Commit changes to that branch with clear, scoped commit messages. +- Open a PR to `main`; do not push directly to `main`. +- Wait for required checks and review outcomes before merging. +- Merge via PR controls (squash/rebase/merge as repository policy allows). +- Branch deletion after merge is optional; long-lived branches are allowed when intentionally maintained. + +### 6.4 Worktree Workflow (Required for Multi-Track Agent Work) + +Use Git worktrees to isolate concurrent agent/human tracks safely and predictably: + +- Use one worktree per active branch/PR stream to avoid cross-task contamination. +- Keep each worktree on a single branch; do not mix unrelated edits in one worktree. +- Run validation commands inside the corresponding worktree before commit/PR. +- Name worktrees clearly by scope (for example: `wt/ci-hardening`, `wt/provider-fix`) and remove stale worktrees when no longer needed. +- PR checkpoint rules from section 6.3 still apply to worktree-based development. + +### 6.1 Code Naming Contract (Required) + +Apply these naming rules for all code changes unless a subsystem has a stronger existing pattern. + +- Use Rust standard casing consistently: modules/files `snake_case`, types/traits/enums `PascalCase`, functions/variables `snake_case`, constants/statics `SCREAMING_SNAKE_CASE`. +- Name types and modules by domain role, not implementation detail (for example `DiscordChannel`, `SecurityPolicy`, `MemoryStore` over vague names like `Manager`/`Helper`). +- Keep trait implementer naming explicit and predictable: `Provider`, `Channel`, `Tool`, `Memory`. +- Keep factory registration keys stable, lowercase, and user-facing (for example `"openai"`, `"discord"`, `"shell"`), and avoid alias sprawl without migration need. +- Name tests by behavior/outcome (`_`) and keep fixture identifiers neutral/project-scoped. +- If identity-like naming is required in tests/examples, use ZeroClaw-native labels only (`ZeroClawAgent`, `zeroclaw_user`, `zeroclaw_node`). + +### 6.2 Architecture Boundary Contract (Required) + +Use these rules to keep the trait/factory architecture stable under growth. + +- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features. +- Keep dependency direction inward to contracts: concrete integrations depend on trait/config/util layers, not on other concrete integrations. +- Avoid creating cross-subsystem coupling (for example provider code importing channel internals, tool code mutating gateway policy directly). +- Keep module responsibilities single-purpose: orchestration in `agent/`, transport in `channels/`, model I/O in `providers/`, policy in `security/`, execution in `tools/`. +- Introduce new shared abstractions only after repeated use (rule-of-three), with at least one real caller in current scope. +- For config/schema changes, treat keys as public contract: document defaults, compatibility impact, and migration/rollback path. + +## 7) Change Playbooks + +### 7.1 Adding a Provider + +- Implement `Provider` in `src/providers/`. +- Register in `src/providers/mod.rs` factory. +- Add focused tests for factory wiring and error paths. +- Avoid provider-specific behavior leaks into shared orchestration code. + +### 7.2 Adding a Channel + +- Implement `Channel` in `src/channels/`. +- Keep `send`, `listen`, `health_check`, typing semantics consistent. +- Cover auth/allowlist/health behavior with tests. + +### 7.3 Adding a Tool + +- Implement `Tool` in `src/tools/` with strict parameter schema. +- Validate and sanitize all inputs. +- Return structured `ToolResult`; avoid panics in runtime path. + +### 5.4 Adding a Peripheral + +- Implement `Peripheral` in `src/peripherals/`. +- Peripherals expose `tools()` — each tool delegates to the hardware (GPIO, sensors, etc.). +- Register board type in config schema if needed. +- See `docs/hardware-peripherals-design.md` for protocol and firmware notes. + +### 5.5 Security / Runtime / Gateway Changes + +- Include threat/risk notes and rollback strategy. +- Add/update tests or validation evidence for failure modes and boundaries. +- Keep observability useful but non-sensitive. +- For `.github/workflows/**` changes, include Actions allowlist impact in PR notes and update `docs/actions-source-policy.md` when sources change. + +## 8) Validation Matrix + +Default local checks for code changes: + +```bash +cargo fmt --all -- --check +cargo clippy --all-targets -- -D warnings +cargo test +``` + +Preferred local pre-PR validation path (recommended, not required): + +```bash +./dev/ci.sh all +``` + +Notes: + +- Local Docker-based CI is strongly recommended when Docker is available. +- Contributors are not blocked from opening a PR if local Docker CI is unavailable; in that case run the most relevant native checks and document what was run. + +Additional expectations by change type: + +- **Docs/template-only**: run markdown lint and relevant doc checks. +- **Workflow changes**: validate YAML syntax; run workflow lint/sanity checks when available. +- **Security/runtime/gateway/tools**: include at least one boundary/failure-mode validation. + +If full checks are impractical, run the most relevant subset and document what was skipped and why. + +## 9) Collaboration and PR Discipline + +- Follow `.github/pull_request_template.md` fully (including side effects / blast radius). +- Keep PR descriptions concrete: problem, change, non-goals, risk, rollback. +- Use conventional commit titles. +- Prefer small PRs (`size: XS/S/M`) when possible. +- Agent-assisted PRs are welcome, **but contributors remain accountable for understanding what their code will do**. + +### 9.1 Privacy/Sensitive Data and Neutral Wording (Required) + +Treat privacy and neutrality as merge gates, not best-effort guidelines. + +- Never commit personal or sensitive data in code, docs, tests, fixtures, snapshots, logs, examples, or commit messages. +- Prohibited data includes (non-exhaustive): real names, personal emails, phone numbers, addresses, access tokens, API keys, credentials, IDs, and private URLs. +- Use neutral project-scoped placeholders (for example: `user_a`, `test_user`, `project_bot`, `example.com`) instead of real identity data. +- Test names/messages/fixtures must be impersonal and system-focused; avoid first-person or identity-specific language. +- If identity-like context is unavoidable, use ZeroClaw-scoped roles/labels only (for example: `ZeroClawAgent`, `ZeroClawOperator`, `zeroclaw_user`) and avoid real-world personas. +- Recommended identity-safe naming palette (use when identity-like context is required): + - actor labels: `ZeroClawAgent`, `ZeroClawOperator`, `ZeroClawMaintainer`, `zeroclaw_user` + - service/runtime labels: `zeroclaw_bot`, `zeroclaw_service`, `zeroclaw_runtime`, `zeroclaw_node` + - environment labels: `zeroclaw_project`, `zeroclaw_workspace`, `zeroclaw_channel` +- If reproducing external incidents, redact and anonymize all payloads before committing. +- Before push, review `git diff --cached` specifically for accidental sensitive strings and identity leakage. + +### 9.2 Superseded-PR Attribution (Required) + +When a PR supersedes another contributor's PR and carries forward substantive code or design decisions, preserve authorship explicitly. + +- In the integrating commit message, add one `Co-authored-by: Name ` trailer per superseded contributor whose work is materially incorporated. +- Use a GitHub-recognized email (`` or the contributor's verified commit email) so attribution is rendered correctly. +- Keep trailers on their own lines after a blank line at commit-message end; never encode them as escaped `\\n` text. +- In the PR body, list superseded PR links and briefly state what was incorporated from each. +- If no actual code/design was incorporated (only inspiration), do not use `Co-authored-by`; give credit in PR notes instead. + +### 9.3 Superseded-PR PR Template (Recommended) + +When superseding multiple PRs, use a consistent title/body structure to reduce reviewer ambiguity. + +- Recommended title format: `feat(): unify and supersede #, # [and #]` +- If this is docs/chore/meta only, keep the same supersede suffix and use the appropriate conventional-commit type. +- In the PR body, include the following template (fill placeholders, remove non-applicable lines): + +```md +## Supersedes +- # by @ +- # by @ +- # by @ + +## Integrated Scope +- From #: +- From #: +- From #: + +## Attribution +- Co-authored-by trailers added for materially incorporated contributors: Yes/No +- If No, explain why (for example: no direct code/design carry-over) + +## Non-goals +- + +## Risk and Rollback +- Risk: +- Rollback: +``` + +### 9.4 Superseded-PR Commit Template (Recommended) + +When a commit unifies or supersedes prior PR work, use a deterministic commit message layout so attribution is machine-parsed and reviewer-friendly. + +- Keep one blank line between message sections, and exactly one blank line before trailer lines. +- Keep each trailer on its own line; do not wrap, indent, or encode as escaped `\n` text. +- Add one `Co-authored-by` trailer per materially incorporated contributor, using GitHub-recognized email. +- If no direct code/design is carried over, omit `Co-authored-by` and explain attribution in the PR body instead. + +```text +feat(): unify and supersede #, # [and #] + + + +Supersedes: +- # by @ +- # by @ +- # by @ + +Integrated scope: +- : from # +- : from # + +Co-authored-by: +Co-authored-by: +``` + +Reference docs: + +- `CONTRIBUTING.md` +- `docs/pr-workflow.md` +- `docs/reviewer-playbook.md` +- `docs/ci-map.md` +- `docs/actions-source-policy.md` + +## 10) Anti-Patterns (Do Not) + +- Do not add heavy dependencies for minor convenience. +- Do not silently weaken security policy or access constraints. +- Do not add speculative config/feature flags “just in case”. +- Do not mix massive formatting-only changes with functional changes. +- Do not modify unrelated modules “while here”. +- Do not bypass failing checks without explicit explanation. +- Do not hide behavior-changing side effects in refactor commits. +- Do not include personal identity or sensitive information in test data, examples, docs, or commits. + +## 11) Handoff Template (Agent -> Agent / Maintainer) + +When handing off work, include: + +1. What changed +2. What did not change +3. Validation run and results +4. Remaining risks / unknowns +5. Next recommended action + +## 12) Vibe Coding Guardrails + +When working in fast iterative mode: + +- Keep each iteration reversible (small commits, clear rollback). +- Validate assumptions with code search before implementing. +- Prefer deterministic behavior over clever shortcuts. +- Do not “ship and hope” on security-sensitive paths. +- If uncertain, leave a concrete TODO with verification context, not a hidden guess. diff --git a/CHANGELOG.md b/CHANGELOG.md index e1ac7be..79e1712 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `enc:` prefix for encrypted secrets — Use `enc2:` (ChaCha20-Poly1305) instead. Legacy values are still decrypted for backward compatibility but should be migrated. -## [0.1.0] - 2025-02-13 +## [0.1.0] - 2026-02-13 ### Added - **Core Architecture**: Trait-based pluggable system for Provider, Channel, Observer, RuntimeAdapter, Tool diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..be37697 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,413 @@ +# CLAUDE.md — ZeroClaw Agent Engineering Protocol + +This file defines the default working protocol for claude code in this repository. +Scope: entire repository. + +## 1) Project Snapshot (Read First) + +ZeroClaw is a Rust-first autonomous agent runtime optimized for: + +- high performance +- high efficiency +- high stability +- high extensibility +- high sustainability +- high security + +Core architecture is trait-driven and modular. Most extension work should be done by implementing traits and registering in factory modules. + +Key extension points: + +- `src/providers/traits.rs` (`Provider`) +- `src/channels/traits.rs` (`Channel`) +- `src/tools/traits.rs` (`Tool`) +- `src/memory/traits.rs` (`Memory`) +- `src/observability/traits.rs` (`Observer`) +- `src/runtime/traits.rs` (`RuntimeAdapter`) +- `src/peripherals/traits.rs` (`Peripheral`) — hardware boards (STM32, RPi GPIO) + +## 2) Deep Architecture Observations (Why This Protocol Exists) + +These codebase realities should drive every design decision: + +1. **Trait + factory architecture is the stability backbone** + - Extension points are intentionally explicit and swappable. + - Most features should be added via trait implementation + factory registration, not cross-cutting rewrites. +2. **Security-critical surfaces are first-class and internet-adjacent** + - `src/gateway/`, `src/security/`, `src/tools/`, `src/runtime/` carry high blast radius. + - Defaults already lean secure-by-default (pairing, bind safety, limits, secret handling); keep it that way. +3. **Performance and binary size are product goals, not nice-to-have** + - `Cargo.toml` release profile and dependency choices optimize for size and determinism. + - Convenience dependencies and broad abstractions can silently regress these goals. +4. **Config and runtime contracts are user-facing API** + - `src/config/schema.rs` and CLI commands are effectively public interfaces. + - Backward compatibility and explicit migration matter. +5. **The project now runs in high-concurrency collaboration mode** + - CI + docs governance + label routing are part of the product delivery system. + - PR throughput is a design constraint; not just a maintainer inconvenience. + +## 3) Engineering Principles (Normative) + +These principles are mandatory by default. They are not slogans; they are implementation constraints. + +### 3.1 KISS (Keep It Simple, Stupid) + +**Why here:** Runtime + security behavior must stay auditable under pressure. + +Required: + +- Prefer straightforward control flow over clever meta-programming. +- Prefer explicit match branches and typed structs over hidden dynamic behavior. +- Keep error paths obvious and localized. + +### 3.2 YAGNI (You Aren't Gonna Need It) + +**Why here:** Premature features increase attack surface and maintenance burden. + +Required: + +- Do not add new config keys, trait methods, feature flags, or workflow branches without a concrete accepted use case. +- Do not introduce speculative “future-proof” abstractions without at least one current caller. +- Keep unsupported paths explicit (error out) rather than adding partial fake support. + +### 3.3 DRY + Rule of Three + +**Why here:** Naive DRY can create brittle shared abstractions across providers/channels/tools. + +Required: + +- Duplicate small, local logic when it preserves clarity. +- Extract shared utilities only after repeated, stable patterns (rule-of-three). +- When extracting, preserve module boundaries and avoid hidden coupling. + +### 3.4 SRP + ISP (Single Responsibility + Interface Segregation) + +**Why here:** Trait-driven architecture already encodes subsystem boundaries. + +Required: + +- Keep each module focused on one concern. +- Extend behavior by implementing existing narrow traits whenever possible. +- Avoid fat interfaces and “god modules” that mix policy + transport + storage. + +### 3.5 Fail Fast + Explicit Errors + +**Why here:** Silent fallback in agent runtimes can create unsafe or costly behavior. + +Required: + +- Prefer explicit `bail!`/errors for unsupported or unsafe states. +- Never silently broaden permissions/capabilities. +- Document fallback behavior when fallback is intentional and safe. + +### 3.6 Secure by Default + Least Privilege + +**Why here:** Gateway/tools/runtime can execute actions with real-world side effects. + +Required: + +- Deny-by-default for access and exposure boundaries. +- Never log secrets, raw tokens, or sensitive payloads. +- Keep network/filesystem/shell scope as narrow as possible unless explicitly justified. + +### 3.7 Determinism + Reproducibility + +**Why here:** Reliable CI and low-latency triage depend on deterministic behavior. + +Required: + +- Prefer reproducible commands and locked dependency behavior in CI-sensitive paths. +- Keep tests deterministic (no flaky timing/network dependence without guardrails). +- Ensure local validation commands map to CI expectations. + +### 3.8 Reversibility + Rollback-First Thinking + +**Why here:** Fast recovery is mandatory under high PR volume. + +Required: + +- Keep changes easy to revert (small scope, clear blast radius). +- For risky changes, define rollback path before merge. +- Avoid mixed mega-patches that block safe rollback. + +## 4) Repository Map (High-Level) + +- `src/main.rs` — CLI entrypoint and command routing +- `src/lib.rs` — module exports and shared command enums +- `src/config/` — schema + config loading/merging +- `src/agent/` — orchestration loop +- `src/gateway/` — webhook/gateway server +- `src/security/` — policy, pairing, secret store +- `src/memory/` — markdown/sqlite memory backends + embeddings/vector merge +- `src/providers/` — model providers and resilient wrapper +- `src/channels/` — Telegram/Discord/Slack/etc channels +- `src/tools/` — tool execution surface (shell, file, memory, browser) +- `src/peripherals/` — hardware peripherals (STM32, RPi GPIO); see `docs/hardware-peripherals-design.md` +- `src/runtime/` — runtime adapters (currently native) +- `docs/` — architecture + process docs +- `.github/` — CI, templates, automation workflows + +## 5) Risk Tiers by Path (Review Depth Contract) + +Use these tiers when deciding validation depth and review rigor. + +- **Low risk**: docs/chore/tests-only changes +- **Medium risk**: most `src/**` behavior changes without boundary/security impact +- **High risk**: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**`, access-control boundaries + +When uncertain, classify as higher risk. + +## 6) Agent Workflow (Required) + +1. **Read before write** + - Inspect existing module, factory wiring, and adjacent tests before editing. +2. **Define scope boundary** + - One concern per PR; avoid mixed feature+refactor+infra patches. +3. **Implement minimal patch** + - Apply KISS/YAGNI/DRY rule-of-three explicitly. +4. **Validate by risk tier** + - Docs-only: lightweight checks. + - Code/risky changes: full relevant checks and focused scenarios. +5. **Document impact** + - Update docs/PR notes for behavior, risk, side effects, and rollback. +6. **Respect queue hygiene** + - If stacked PR: declare `Depends on #...`. + - If replacing old PR: declare `Supersedes #...`. + +### 6.3 Branch / Commit / PR Flow (Required) + +All contributors (human or agent) must follow the same collaboration flow: + +- Create and work from a non-`main` branch. +- Commit changes to that branch with clear, scoped commit messages. +- Open a PR to `main`; do not push directly to `main`. +- Wait for required checks and review outcomes before merging. +- Merge via PR controls (squash/rebase/merge as repository policy allows). +- Branch deletion after merge is optional; long-lived branches are allowed when intentionally maintained. + +### 6.4 Worktree Workflow (Required for Multi-Track Agent Work) + +Use Git worktrees to isolate concurrent agent/human tracks safely and predictably: + +- Use one worktree per active branch/PR stream to avoid cross-task contamination. +- Keep each worktree on a single branch; do not mix unrelated edits in one worktree. +- Run validation commands inside the corresponding worktree before commit/PR. +- Name worktrees clearly by scope (for example: `wt/ci-hardening`, `wt/provider-fix`) and remove stale worktrees when no longer needed. +- PR checkpoint rules from section 6.3 still apply to worktree-based development. + +### 6.1 Code Naming Contract (Required) + +Apply these naming rules for all code changes unless a subsystem has a stronger existing pattern. + +- Use Rust standard casing consistently: modules/files `snake_case`, types/traits/enums `PascalCase`, functions/variables `snake_case`, constants/statics `SCREAMING_SNAKE_CASE`. +- Name types and modules by domain role, not implementation detail (for example `DiscordChannel`, `SecurityPolicy`, `MemoryStore` over vague names like `Manager`/`Helper`). +- Keep trait implementer naming explicit and predictable: `Provider`, `Channel`, `Tool`, `Memory`. +- Keep factory registration keys stable, lowercase, and user-facing (for example `"openai"`, `"discord"`, `"shell"`), and avoid alias sprawl without migration need. +- Name tests by behavior/outcome (`_`) and keep fixture identifiers neutral/project-scoped. +- If identity-like naming is required in tests/examples, use ZeroClaw-native labels only (`ZeroClawAgent`, `zeroclaw_user`, `zeroclaw_node`). + +### 6.2 Architecture Boundary Contract (Required) + +Use these rules to keep the trait/factory architecture stable under growth. + +- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features. +- Keep dependency direction inward to contracts: concrete integrations depend on trait/config/util layers, not on other concrete integrations. +- Avoid creating cross-subsystem coupling (for example provider code importing channel internals, tool code mutating gateway policy directly). +- Keep module responsibilities single-purpose: orchestration in `agent/`, transport in `channels/`, model I/O in `providers/`, policy in `security/`, execution in `tools/`. +- Introduce new shared abstractions only after repeated use (rule-of-three), with at least one real caller in current scope. +- For config/schema changes, treat keys as public contract: document defaults, compatibility impact, and migration/rollback path. + +## 7) Change Playbooks + +### 7.1 Adding a Provider + +- Implement `Provider` in `src/providers/`. +- Register in `src/providers/mod.rs` factory. +- Add focused tests for factory wiring and error paths. +- Avoid provider-specific behavior leaks into shared orchestration code. + +### 7.2 Adding a Channel + +- Implement `Channel` in `src/channels/`. +- Keep `send`, `listen`, `health_check`, typing semantics consistent. +- Cover auth/allowlist/health behavior with tests. + +### 7.3 Adding a Tool + +- Implement `Tool` in `src/tools/` with strict parameter schema. +- Validate and sanitize all inputs. +- Return structured `ToolResult`; avoid panics in runtime path. + +### 5.4 Adding a Peripheral + +- Implement `Peripheral` in `src/peripherals/`. +- Peripherals expose `tools()` — each tool delegates to the hardware (GPIO, sensors, etc.). +- Register board type in config schema if needed. +- See `docs/hardware-peripherals-design.md` for protocol and firmware notes. + +### 5.5 Security / Runtime / Gateway Changes + +- Include threat/risk notes and rollback strategy. +- Add/update tests or validation evidence for failure modes and boundaries. +- Keep observability useful but non-sensitive. +- For `.github/workflows/**` changes, include Actions allowlist impact in PR notes and update `docs/actions-source-policy.md` when sources change. + +## 8) Validation Matrix + +Default local checks for code changes: + +```bash +cargo fmt --all -- --check +cargo clippy --all-targets -- -D warnings +cargo test +``` + +Preferred local pre-PR validation path (recommended, not required): + +```bash +./dev/ci.sh all +``` + +Notes: + +- Local Docker-based CI is strongly recommended when Docker is available. +- Contributors are not blocked from opening a PR if local Docker CI is unavailable; in that case run the most relevant native checks and document what was run. + +Additional expectations by change type: + +- **Docs/template-only**: run markdown lint and relevant doc checks. +- **Workflow changes**: validate YAML syntax; run workflow lint/sanity checks when available. +- **Security/runtime/gateway/tools**: include at least one boundary/failure-mode validation. + +If full checks are impractical, run the most relevant subset and document what was skipped and why. + +## 9) Collaboration and PR Discipline + +- Follow `.github/pull_request_template.md` fully (including side effects / blast radius). +- Keep PR descriptions concrete: problem, change, non-goals, risk, rollback. +- Use conventional commit titles. +- Prefer small PRs (`size: XS/S/M`) when possible. +- Agent-assisted PRs are welcome, **but contributors remain accountable for understanding what their code will do**. + +### 9.1 Privacy/Sensitive Data and Neutral Wording (Required) + +Treat privacy and neutrality as merge gates, not best-effort guidelines. + +- Never commit personal or sensitive data in code, docs, tests, fixtures, snapshots, logs, examples, or commit messages. +- Prohibited data includes (non-exhaustive): real names, personal emails, phone numbers, addresses, access tokens, API keys, credentials, IDs, and private URLs. +- Use neutral project-scoped placeholders (for example: `user_a`, `test_user`, `project_bot`, `example.com`) instead of real identity data. +- Test names/messages/fixtures must be impersonal and system-focused; avoid first-person or identity-specific language. +- If identity-like context is unavoidable, use ZeroClaw-scoped roles/labels only (for example: `ZeroClawAgent`, `ZeroClawOperator`, `zeroclaw_user`) and avoid real-world personas. +- Recommended identity-safe naming palette (use when identity-like context is required): + - actor labels: `ZeroClawAgent`, `ZeroClawOperator`, `ZeroClawMaintainer`, `zeroclaw_user` + - service/runtime labels: `zeroclaw_bot`, `zeroclaw_service`, `zeroclaw_runtime`, `zeroclaw_node` + - environment labels: `zeroclaw_project`, `zeroclaw_workspace`, `zeroclaw_channel` +- If reproducing external incidents, redact and anonymize all payloads before committing. +- Before push, review `git diff --cached` specifically for accidental sensitive strings and identity leakage. + +### 9.2 Superseded-PR Attribution (Required) + +When a PR supersedes another contributor's PR and carries forward substantive code or design decisions, preserve authorship explicitly. + +- In the integrating commit message, add one `Co-authored-by: Name ` trailer per superseded contributor whose work is materially incorporated. +- Use a GitHub-recognized email (`` or the contributor's verified commit email) so attribution is rendered correctly. +- Keep trailers on their own lines after a blank line at commit-message end; never encode them as escaped `\\n` text. +- In the PR body, list superseded PR links and briefly state what was incorporated from each. +- If no actual code/design was incorporated (only inspiration), do not use `Co-authored-by`; give credit in PR notes instead. + +### 9.3 Superseded-PR PR Template (Recommended) + +When superseding multiple PRs, use a consistent title/body structure to reduce reviewer ambiguity. + +- Recommended title format: `feat(): unify and supersede #, # [and #]` +- If this is docs/chore/meta only, keep the same supersede suffix and use the appropriate conventional-commit type. +- In the PR body, include the following template (fill placeholders, remove non-applicable lines): + +```md +## Supersedes +- # by @ +- # by @ +- # by @ + +## Integrated Scope +- From #: +- From #: +- From #: + +## Attribution +- Co-authored-by trailers added for materially incorporated contributors: Yes/No +- If No, explain why (for example: no direct code/design carry-over) + +## Non-goals +- + +## Risk and Rollback +- Risk: +- Rollback: +``` + +### 9.4 Superseded-PR Commit Template (Recommended) + +When a commit unifies or supersedes prior PR work, use a deterministic commit message layout so attribution is machine-parsed and reviewer-friendly. + +- Keep one blank line between message sections, and exactly one blank line before trailer lines. +- Keep each trailer on its own line; do not wrap, indent, or encode as escaped `\n` text. +- Add one `Co-authored-by` trailer per materially incorporated contributor, using GitHub-recognized email. +- If no direct code/design is carried over, omit `Co-authored-by` and explain attribution in the PR body instead. + +```text +feat(): unify and supersede #, # [and #] + + + +Supersedes: +- # by @ +- # by @ +- # by @ + +Integrated scope: +- : from # +- : from # + +Co-authored-by: +Co-authored-by: +``` + +Reference docs: + +- `CONTRIBUTING.md` +- `docs/pr-workflow.md` +- `docs/reviewer-playbook.md` +- `docs/ci-map.md` +- `docs/actions-source-policy.md` + +## 10) Anti-Patterns (Do Not) + +- Do not add heavy dependencies for minor convenience. +- Do not silently weaken security policy or access constraints. +- Do not add speculative config/feature flags “just in case”. +- Do not mix massive formatting-only changes with functional changes. +- Do not modify unrelated modules “while here”. +- Do not bypass failing checks without explicit explanation. +- Do not hide behavior-changing side effects in refactor commits. +- Do not include personal identity or sensitive information in test data, examples, docs, or commits. + +## 11) Handoff Template (Agent -> Agent / Maintainer) + +When handing off work, include: + +1. What changed +2. What did not change +3. Validation run and results +4. Remaining risks / unknowns +5. Next recommended action + +## 12) Vibe Coding Guardrails + +When working in fast iterative mode: + +- Keep each iteration reversible (small commits, clear rollback). +- Validate assumptions with code search before implementing. +- Prefer deterministic behavior over clever shortcuts. +- Do not “ship and hope” on security-sensitive paths. +- If uncertain, leave a concrete TODO with verification context, not a hidden guess. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c319cc5..d98a2ce 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ Thanks for your interest in contributing to ZeroClaw! This guide will help you g ```bash # Clone the repo -git clone https://github.com/theonlyhennygod/zeroclaw.git +git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw # Enable the pre-push hook (runs fmt, clippy, tests before every push) @@ -16,18 +16,60 @@ git config core.hooksPath .githooks cargo build # Run tests (all must pass) -cargo test +cargo test --locked -# Format & lint (must pass before PR) -cargo fmt && cargo clippy -- -D warnings +# Format & lint (required before PR) +./scripts/ci/rust_quality_gate.sh + +# Optional strict lint audit (full repo, recommended periodically) +./scripts/ci/rust_quality_gate.sh --strict + +# Optional strict lint delta gate (blocks only changed Rust lines) +./scripts/ci/rust_strict_delta_gate.sh + +# Optional docs lint gate (blocks only markdown issues on changed lines) +./scripts/ci/docs_quality_gate.sh + +# Optional docs links gate (checks only links added on changed lines) +./scripts/ci/docs_links_gate.sh # Release build (~3.4MB) -cargo build --release +cargo build --release --locked ``` ### Pre-push hook -The repo includes a pre-push hook in `.githooks/` that enforces `cargo fmt --check`, `cargo clippy -- -D warnings`, and `cargo test` before every push. Enable it with `git config core.hooksPath .githooks`. +The repo includes a pre-push hook in `.githooks/` that enforces `./scripts/ci/rust_quality_gate.sh` and `cargo test --locked` before every push. Enable it with `git config core.hooksPath .githooks`. + +For an opt-in strict lint pass during pre-push, set: + +```bash +ZEROCLAW_STRICT_LINT=1 git push +``` + +For an opt-in strict lint delta pass during pre-push (changed Rust lines only), set: + +```bash +ZEROCLAW_STRICT_DELTA_LINT=1 git push +``` + +For an opt-in docs quality pass during pre-push (changed-line markdown gate), set: + +```bash +ZEROCLAW_DOCS_LINT=1 git push +``` + +For an opt-in docs links pass during pre-push (added-links gate), set: + +```bash +ZEROCLAW_DOCS_LINKS=1 git push +``` + +For full CI parity in Docker, run: + +```bash +./dev/ci.sh all +``` To skip it during rapid iteration: @@ -37,6 +79,182 @@ git push --no-verify > **Note:** CI runs the same checks, so skipped hooks will be caught on the PR. +## Local Secret Management (Required) + +ZeroClaw supports layered secret management for local development and CI hygiene. + +### Secret Storage Options + +1. **Environment variables** (recommended for local development) + - Copy `.env.example` to `.env` and fill in values + - `.env` files are Git-ignored and should stay local + - Best for temporary/local API keys + +2. **Config file** (`~/.zeroclaw/config.toml`) + - Persistent setup for long-term use + - When `secrets.encrypt = true` (default), secret values are encrypted before save + - Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions + - Use `zeroclaw onboard` for guided setup + +### Runtime Resolution Rules + +API key resolution follows this order: + +1. Explicit key passed from config/CLI +2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...) +3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`) + +Provider/model config overrides: + +- `ZEROCLAW_PROVIDER` / `PROVIDER` +- `ZEROCLAW_MODEL` + +See `.env.example` for practical examples and currently supported provider key env vars. + +### Pre-Commit Secret Hygiene (Mandatory) + +Before every commit, verify: + +- [ ] No `.env` files are staged (`.env.example` only) +- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages +- [ ] No credentials in debug output or error payloads +- [ ] `git diff --cached` has no accidental secret-like strings + +Quick local audit: + +```bash +# Search staged diff for common secret markers +git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)' + +# Confirm no .env file is staged +git status --short | grep -E '\.env$' +``` + +### Optional Local Secret Scanning + +For extra guardrails, install one of: + +- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks) +- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog) +- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets) + +This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed. + +Enable hooks with: + +```bash +git config core.hooksPath .githooks +``` + +If gitleaks is not installed, the pre-commit hook prints a warning and continues. + +### What Must Never Be Committed + +- `.env` files (use `.env.example` only) +- API keys, tokens, passwords, or credentials (plain or encrypted) +- OAuth tokens or session identifiers +- Webhook signing secrets +- `~/.zeroclaw/.secret_key` or similar key files +- Personal identifiers or real user data in tests/fixtures + +### If a Secret Is Committed Accidentally + +1. Revoke/rotate the credential immediately +2. Do not rely only on `git revert` (history still contains the secret) +3. Purge history with `git filter-repo` or BFG +4. Force-push cleaned history (coordinate with maintainers) +5. Ensure the leaked value is removed from PR/issue/discussion/comment history + +Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository) + +## Collaboration Tracks (Risk-Based) + +To keep review throughput high without lowering quality, every PR should map to one track: + +| Track | Typical scope | Required review depth | +|---|---|---| +| **Track A (Low risk)** | docs/tests/chore, isolated refactors, no security/runtime/CI impact | 1 maintainer review + green `CI Required Gate` | +| **Track B (Medium risk)** | providers/channels/memory/tools behavior changes | 1 subsystem-aware review + explicit validation evidence | +| **Track C (High risk)** | `src/security/**`, `src/runtime/**`, `src/gateway/**`, `.github/workflows/**`, access-control boundaries | 2-pass review (fast triage + deep risk review), rollback plan required | + +When in doubt, choose the higher track. + +## Documentation Optimization Principles + +To keep docs useful under high PR volume, we use these rules: + +- **Single source of truth**: policy lives in docs, not scattered across PR comments. +- **Decision-oriented content**: every checklist item should directly help accept/reject a change. +- **Risk-proportionate detail**: high-risk paths need deeper evidence; low-risk paths stay lightweight. +- **Side-effect visibility**: document blast radius, failure modes, and rollback before merge. +- **Automation assists, humans decide**: bots triage and label, but merge accountability stays human. + +### Documentation System Map + +| Doc | Primary purpose | When to update | +|---|---|---| +| `CONTRIBUTING.md` | contributor contract and readiness baseline | contributor expectations or policy changes | +| `docs/pr-workflow.md` | governance logic and merge contract | workflow/risk/merge gate changes | +| `docs/reviewer-playbook.md` | reviewer operating checklist | review depth or triage behavior changes | +| `docs/ci-map.md` | CI ownership and triage entry points | workflow trigger/job ownership changes | + +## PR Definition of Ready (DoR) + +Before requesting review, ensure all of the following are true: + +- Scope is focused to a single concern. +- `.github/pull_request_template.md` is fully completed. +- Relevant local validation has been run (`fmt`, `clippy`, `test`, scenario checks). +- Security impact and rollback path are explicitly described. +- No personal/sensitive data is introduced in code/docs/tests/fixtures/logs/examples/commit messages. +- Tests/fixtures/examples use neutral project-scoped wording (no identity-specific or first-person phrasing). +- If identity-like wording is required, use ZeroClaw-centric labels only (for example: `ZeroClawAgent`, `ZeroClawOperator`, `zeroclaw_user`). +- Linked issue (or rationale for no issue) is included. + +## PR Definition of Done (DoD) + +A PR is merge-ready when: + +- `CI Required Gate` is green. +- Required reviewers approved (including CODEOWNERS paths). +- Risk level matches changed paths (`risk: low/medium/high`). +- User-visible behavior, migration, and rollback notes are complete. +- Follow-up TODOs are explicit and tracked in issues. + +## High-Volume Collaboration Rules + +When PR traffic is high (especially with AI-assisted contributions), these rules keep quality and throughput stable: + +- **One concern per PR**: avoid mixing refactor + feature + infra in one change. +- **Small PRs first**: prefer PR size `XS/S/M`; split large work into stacked PRs. +- **Template is mandatory**: complete every section in `.github/pull_request_template.md`. +- **Explicit rollback**: every PR must include a fast rollback path. +- **Security-first review**: changes in `src/security/`, runtime, gateway, and CI need stricter validation. +- **Risk-first triage**: use labels (`risk: high`, `risk: medium`, `risk: low`) to route review depth. +- **Privacy-first hygiene**: redact/anonymize sensitive payloads and keep tests/examples neutral and project-scoped. +- **Identity normalization**: when identity traits are unavoidable, use ZeroClaw/project-native roles instead of personal or real-world identities. +- **Supersede hygiene**: if your PR replaces an older open PR, add `Supersedes #...` and request maintainers close the outdated one. + +Full maintainer workflow: [`docs/pr-workflow.md`](docs/pr-workflow.md). +CI workflow ownership and triage map: [`docs/ci-map.md`](docs/ci-map.md). +Reviewer operating checklist: [`docs/reviewer-playbook.md`](docs/reviewer-playbook.md). + +## Agent Collaboration Guidance + +Agent-assisted contributions are welcome and treated as first-class contributions. + +For smoother agent-to-agent and human-to-agent review: + +- Keep PR summaries concrete (problem, change, non-goals). +- Include reproducible validation evidence (`fmt`, `clippy`, `test`, scenario checks). +- Add brief workflow notes when automation materially influenced design/code. +- Agent-assisted PRs are welcome, but contributors remain accountable for understanding what the code does and what it could affect. +- Call out uncertainty and risky edges explicitly. + +We do **not** require PRs to declare an AI-vs-human line ratio. + +Agent implementation playbook lives in [`AGENTS.md`](AGENTS.md). + ## Architecture: Trait-Based Pluggability ZeroClaw's architecture is built on **traits** — every subsystem is swappable. This means contributing a new integration is as simple as implementing a trait and registering it in the factory function. @@ -52,6 +270,57 @@ src/ └── security/ # Sandboxing → SecurityPolicy ``` +## Code Naming Conventions (Required) + +Use these defaults unless an existing subsystem pattern clearly overrides them. + +- **Rust casing**: modules/files `snake_case`, types/traits/enums `PascalCase`, functions/variables `snake_case`, constants `SCREAMING_SNAKE_CASE`. +- **Domain-first naming**: prefer explicit role names such as `DiscordChannel`, `SecurityPolicy`, `SqliteMemory` over ambiguous names (`Manager`, `Util`, `Helper`). +- **Trait implementers**: keep predictable suffixes (`*Provider`, `*Channel`, `*Tool`, `*Memory`, `*Observer`, `*RuntimeAdapter`). +- **Factory keys**: keep lowercase and stable (`openai`, `discord`, `shell`); avoid adding aliases without migration need. +- **Tests**: use behavior-oriented names (`subject_expected_behavior`) and neutral project-scoped fixtures. +- **Identity-like labels**: if unavoidable, use ZeroClaw-native identifiers only (`ZeroClawAgent`, `zeroclaw_user`, `zeroclaw_node`). + +## Architecture Boundary Rules (Required) + +Keep architecture extensible and auditable by following these boundaries. + +- Extend features via trait implementations + factory registration before considering broad refactors. +- Keep dependency direction contract-first: concrete integrations depend on shared traits/config/util, not on other concrete integrations. +- Avoid cross-subsystem coupling (provider ↔ channel internals, tools mutating security/gateway internals directly, etc.). +- Keep responsibilities single-purpose by module (`agent` orchestration, `channels` transport, `providers` model I/O, `security` policy, `tools` execution, `memory` persistence). +- Introduce shared abstractions only after repeated stable use (rule-of-three) and at least one current caller. +- Treat `src/config/schema.rs` keys as public contract; document compatibility impact, migration steps, and rollback path for changes. + +## Naming and Architecture Examples (Bad vs Good) + +Use these quick examples to align implementation choices before opening a PR. + +### Naming examples + +- **Bad**: `Manager`, `Helper`, `doStuff`, `tmp_data` +- **Good**: `DiscordChannel`, `SecurityPolicy`, `send_message`, `channel_allowlist` + +- **Bad test name**: `test1` / `works` +- **Good test name**: `allowlist_denies_unknown_user`, `provider_returns_error_on_invalid_model` + +- **Bad identity-like label**: `john_user`, `alice_bot` +- **Good identity-like label**: `ZeroClawAgent`, `zeroclaw_user`, `zeroclaw_node` + +### Architecture boundary examples + +- **Bad**: channel implementation directly imports provider internals to call model APIs. +- **Good**: channel emits normalized `ChannelMessage`; agent/runtime orchestrates provider calls via trait contracts. + +- **Bad**: tool mutates gateway/security policy directly from execution path. +- **Good**: tool returns structured `ToolResult`; policy enforcement remains in security/runtime boundaries. + +- **Bad**: adding broad shared abstraction before any repeated caller. +- **Good**: keep local logic first; extract shared abstraction only after stable rule-of-three evidence. + +- **Bad**: config key changes without migration notes. +- **Good**: config/schema changes include defaults, compatibility impact, migration steps, and rollback guidance. + ## How to Add a New Provider Create `src/providers/your_provider.rs`: @@ -184,13 +453,19 @@ impl Tool for YourTool { ## Pull Request Checklist -- [ ] `cargo fmt` — code is formatted -- [ ] `cargo clippy -- -D warnings` — no warnings -- [ ] `cargo test` — all 129+ tests pass +- [ ] PR template sections are completed (including security + rollback) +- [ ] `./scripts/ci/rust_quality_gate.sh` — merge gate formatter/lint baseline passes +- [ ] `cargo test --locked` — all tests pass locally or skipped tests are explained +- [ ] Optional strict audit: `./scripts/ci/rust_quality_gate.sh --strict` (full repo, run when doing lint cleanup or release-hardening work) +- [ ] Optional strict delta audit: `./scripts/ci/rust_strict_delta_gate.sh` (changed Rust lines only, useful for incremental debt control) - [ ] New code has inline `#[cfg(test)]` tests - [ ] No new dependencies unless absolutely necessary (we optimize for binary size) - [ ] README updated if adding user-facing features - [ ] Follows existing code patterns and conventions +- [ ] Follows code naming conventions and architecture boundary rules in this guide +- [ ] No personal/sensitive data in code/docs/tests/fixtures/logs/examples/commit messages +- [ ] Test names/messages/fixtures/examples are neutral and project-focused +- [ ] Any required identity-like wording uses ZeroClaw/project-native labels only ## Commit Convention @@ -198,6 +473,7 @@ We use [Conventional Commits](https://www.conventionalcommits.org/): ``` feat: add Anthropic provider +feat(provider): add Anthropic provider fix: path traversal edge case with symlinks docs: update contributing guide test: add heartbeat unicode parsing tests @@ -205,6 +481,10 @@ refactor: extract common security checks chore: bump tokio to 1.43 ``` +Recommended scope keys in commit titles: + +- `provider`, `channel`, `memory`, `security`, `runtime`, `ci`, `docs`, `tests` + ## Code Style - **Minimal dependencies** — every crate adds to binary size @@ -218,6 +498,18 @@ chore: bump tokio to 1.43 - **Bugs**: Include OS, Rust version, steps to reproduce, expected vs actual - **Features**: Describe the use case, propose which trait to extend - **Security**: See [SECURITY.md](SECURITY.md) for responsible disclosure +- **Privacy**: Redact/anonymize all personal data and sensitive identifiers before posting logs/payloads + +## Maintainer Merge Policy + +- Require passing `CI Required Gate` before merge. +- Require docs quality checks when docs are touched. +- Require review approval for non-trivial changes. +- Require CODEOWNERS review for protected paths. +- Use risk labels to determine review depth, scope labels (`core`, `provider`, `channel`, `security`, etc.) to route ownership, and module labels (`:`, e.g. `channel:telegram`, `provider:kimi`, `tool:shell`) to route subsystem expertise. +- Contributor tier labels are auto-applied on PRs and issues by merged PR count: `experienced contributor` (>=10), `principal contributor` (>=20), `distinguished contributor` (>=50). Treat them as read-only automation labels; manual edits are auto-corrected. +- Prefer squash merge with conventional commit title. +- Revert fast on regressions; re-land with tests. ## License diff --git a/Cargo.lock b/Cargo.lock index e960ed8..e19c5c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "adobe-cmap-parser" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae8abfa9a4688de8fc9f42b3f013b6fffec18ed8a554f5f113577e0b9b3212a3" +dependencies = [ + "pom", +] + [[package]] name = "aead" version = "0.5.2" @@ -12,6 +27,17 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.12" @@ -24,6 +50,21 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -89,6 +130,33 @@ version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" +[[package]] +name = "ar_archive_writer" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" +dependencies = [ + "object 0.37.3", +] + +[[package]] +name = "async-io" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "456b8a8feb6f42d237746d4b3e9a178494627745c3c56c6ea55d92ba50d026fc" +dependencies = [ + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-io", + "futures-lite", + "parking", + "polling", + "rustix 1.1.3", + "slab", + "windows-sys 0.61.2", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -113,16 +181,39 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] -name = "axum" -version = "0.7.9" +name = "aws-lc-rs" +version = "1.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.37.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b092fe214090261288111db7a2b2c2118e5a7f30dc2569f1732c4069a6840549" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ - "async-trait", "axum-core", + "base64", "bytes", + "form_urlencoded", "futures-util", - "http", + "http 1.4.0", "http-body", "http-body-util", "hyper", @@ -133,13 +224,14 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rustversion", - "serde", + "serde_core", "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -147,19 +239,17 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ - "async-trait", "bytes", - "futures-util", - "http", + "futures-core", + "http 1.4.0", "http-body", "http-body-util", "mime", "pin-project-lite", - "rustversion", "sync_wrapper", "tower-layer", "tower-service", @@ -172,10 +262,71 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "bitflags" -version = "2.10.0" +name = "bincode" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + +[[package]] +name = "bitfield" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21ba6517c6b0f2bf08be60e187ab64b038438f22dd755614d8fe4d4098c46419" +dependencies = [ + "bitfield-macros", +] + +[[package]] +name = "bitfield-macros" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f48d6ace212fdf1b45fd6b566bb40808415344642b76c3224c07c8df9da81e97" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +dependencies = [ + "serde_core", +] + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] [[package]] name = "block-buffer" @@ -186,12 +337,47 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-padding" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" +[[package]] +name = "bytecount" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "byteorder" version = "1.5.0" @@ -204,6 +390,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cbc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +dependencies = [ + "cipher", +] + [[package]] name = "cc" version = "1.2.56" @@ -211,9 +406,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] +[[package]] +name = "cff-parser" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31f5b6e9141c036f3ff4ce7b2f7e432b0f00dee416ddcd4f17741d189ddc2e9d" + [[package]] name = "cfg-if" version = "1.0.4" @@ -258,9 +461,30 @@ checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" dependencies = [ "iana-time-zone", "num-traits", + "serde", "windows-link", ] +[[package]] +name = "chrono-tz" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" +dependencies = [ + "chrono", + "phf", +] + +[[package]] +name = "chumsky" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eebd66744a15ded14960ab4ccdbfb51ad3b81f51f3f04a80adac98c985396c9" +dependencies = [ + "hashbrown 0.14.5", + "stacker", +] + [[package]] name = "cipher" version = "0.4.4" @@ -312,12 +536,40 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "cobs" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ef0193218d365c251b5b9297f9911a908a8ddd2ebd3a36cc5d0ef0f63aee9e" +dependencies = [ + "heapless", + "thiserror 2.0.18", +] + [[package]] name = "colorchoice" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.11" @@ -327,16 +579,79 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "unicode-width", + "unicode-width 0.2.2", "windows-sys 0.59.0", ] +[[package]] +name = "console" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.2", + "windows-sys 0.61.2", +] + +[[package]] +name = "cookie" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e859cd57d0710d9e06c381b550c06e76992472a8c6d527aecd2fc673dcc231fb" +dependencies = [ + "time", + "version_check", +] + +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core_maths" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77745e017f5edba1a9c1d854f6f3a52dac8a12dd5af5d2f54aecf61e43d80d30" +dependencies = [ + "libm", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -346,6 +661,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "cron" version = "0.12.1" @@ -353,10 +677,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f8c3e73077b4b4a6ab1ea5047c37c57aee77657bc8ecd6f29b0af082d0b0c07" dependencies = [ "chrono", - "nom", + "nom 7.1.3", "once_cell", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.7" @@ -368,6 +698,62 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde_core", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -375,16 +761,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" [[package]] -name = "dialoguer" -version = "0.11.0" +name = "deku" +version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +checksum = "a9711031e209dc1306d66985363b4397d4c7b911597580340b93c9729b55f6eb" dependencies = [ - "console", + "bitvec", + "deku_derive", + "no_std_io2", + "rustversion", +] + +[[package]] +name = "deku_derive" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58cb0719583cbe4e81fb40434ace2f0d22ccc3e39a74bb3796c22b451b4f139d" +dependencies = [ + "darling", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "deranged" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc3dc5ad92c2e2d1c193bbbbdf2ea477cb81331de4f3103f267ca18368b988c4" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "dialoguer" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96" +dependencies = [ + "console 0.16.2", "fuzzy-matcher", "shell-words", "tempfile", - "thiserror 1.0.69", "zeroize", ] @@ -396,6 +815,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -451,12 +871,98 @@ dependencies = [ "syn", ] +[[package]] +name = "docsplay" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8547ea80db62c5bb9d7796fcce5e6e07d1136bdc1a02269095061e806758fab4" +dependencies = [ + "docsplay-macros", +] + +[[package]] +name = "docsplay-macros" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11772ed3eb3db124d826f3abeadf5a791a557f62c19b123e3f07288158a71fdd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "email-encoding" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9298e6504d9b9e780ed3f7dfd43a61be8cd0e09eb07f7706a945b0072b6670b6" +dependencies = [ + "base64", + "memchr", +] + +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" + [[package]] name = "encode_unicode" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enumflags2" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -473,6 +979,57 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esp-idf-part" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5ebc2381d030e4e89183554c3fcd4ad44dc5ab34961ab09e09b4adbe4f94b61" +dependencies = [ + "bitflags 2.11.0", + "csv", + "deku", + "md-5", + "parse_int", + "regex", + "serde", + "serde_plain", + "strum", + "thiserror 2.0.18", +] + +[[package]] +name = "espflash" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46f05d15cb2479a3cbbbe684b9f0831b2ae036d9faefd1eb08f21267275862f9" +dependencies = [ + "base64", + "bitflags 2.11.0", + "bytemuck", + "esp-idf-part", + "flate2", + "gimli", + "libc", + "log", + "md-5", + "miette", + "nix 0.30.1", + "object 0.38.1", + "serde", + "sha2", + "strum", + "thiserror 2.0.18", +] + +[[package]] +name = "euclid" +version = "0.20.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bb7ef65b3777a325d1eeefefab5b6d4959da54747e33bd6258e789640f307ad" +dependencies = [ + "num-traits", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -485,6 +1042,29 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" +[[package]] +name = "fantoccini" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d0086bcd59795408c87a04f94b5a8bd62cba2856cfe656c7e6439061d95b760" +dependencies = [ + "base64", + "cookie 0.18.1", + "futures-util", + "http 1.4.0", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "mime", + "serde", + "serde_json", + "time", + "tokio", + "url", + "webdriver", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -497,6 +1077,34 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -507,10 +1115,37 @@ dependencies = [ ] [[package]] -name = "futures-channel" -version = "0.3.31" +name = "fs_extra" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -518,21 +1153,42 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "futures-core", + "pin-project-lite", +] [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -541,22 +1197,23 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -564,7 +1221,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -614,6 +1270,45 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" +dependencies = [ + "fallible-iterator", + "indexmap", + "stable_deref_trait", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -621,6 +1316,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", + "allocator-api2", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash 0.1.5", ] [[package]] @@ -628,14 +1333,38 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash 0.2.0", +] + +[[package]] +name = "hashify" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "149e3ea90eb5a26ad354cfe3cb7f7401b9329032d0235f2687d03a35f30e5d4c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "hashlink" -version = "0.9.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" dependencies = [ - "hashbrown 0.14.5", + "hashbrown 0.16.1", +] + +[[package]] +name = "heapless" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af2455f757db2b292a9b1768c4b70186d443bcb3b316252d6b540aec1cd89ed" +dependencies = [ + "hash32", + "stable_deref_trait", ] [[package]] @@ -644,6 +1373,48 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hidapi" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "565dd4c730b8f8b2c0fb36df6be12e5470ae10895ddcc4e9dcfbfb495de202b0" +dependencies = [ + "cc", + "cfg-if", + "libc", + "nix 0.27.1", + "pkg-config", + "udev", + "windows-sys 0.48.0", +] + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "hostname" version = "0.4.2" @@ -655,6 +1426,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.4.0" @@ -672,7 +1454,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.4.0", ] [[package]] @@ -683,7 +1465,7 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", + "http 1.4.0", "http-body", "pin-project-lite", ] @@ -710,7 +1492,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "http", + "http 1.4.0", "http-body", "httparse", "httpdate", @@ -728,10 +1510,12 @@ version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", + "http 1.4.0", "hyper", "hyper-util", + "log", "rustls", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls", @@ -749,7 +1533,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", + "http 1.4.0", "http-body", "hyper", "ipnet", @@ -786,6 +1570,18 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke 0.7.5", + "zerofrom", + "zerovec 0.10.4", +] + [[package]] name = "icu_collections" version = "2.1.1" @@ -794,9 +1590,9 @@ checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", "potential_utf", - "yoke", + "yoke 0.8.1", "zerofrom", - "zerovec", + "zerovec 0.11.5", ] [[package]] @@ -806,10 +1602,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" dependencies = [ "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", + "litemap 0.8.1", + "tinystr 0.8.2", + "writeable 0.6.2", + "zerovec 0.11.5", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap 0.7.5", + "tinystr 0.7.6", + "writeable 0.5.5", ] [[package]] @@ -818,12 +1626,12 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" dependencies = [ - "icu_collections", + "icu_collections 2.1.1", "icu_normalizer_data", "icu_properties", - "icu_provider", + "icu_provider 2.1.1", "smallvec", - "zerovec", + "zerovec 0.11.5", ] [[package]] @@ -838,12 +1646,12 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ - "icu_collections", + "icu_collections 2.1.1", "icu_locale_core", "icu_properties_data", - "icu_provider", + "icu_provider 2.1.1", "zerotrie", - "zerovec", + "zerovec 0.11.5", ] [[package]] @@ -852,6 +1660,23 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr 0.7.6", + "writeable 0.5.5", + "yoke 0.7.5", + "zerofrom", + "zerovec 0.10.4", +] + [[package]] name = "icu_provider" version = "2.1.1" @@ -860,13 +1685,58 @@ checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" dependencies = [ "displaydoc", "icu_locale_core", - "writeable", - "yoke", + "writeable 0.6.2", + "yoke 0.8.1", "zerofrom", "zerotrie", - "zerovec", + "zerovec 0.11.5", ] +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "icu_segmenter" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a717725612346ffc2d7b42c94b820db6908048f39434504cb130e8b46256b0de" +dependencies = [ + "core_maths", + "displaydoc", + "icu_collections 1.5.0", + "icu_locid", + "icu_provider 1.5.0", + "icu_segmenter_data", + "utf8_iter", + "zerovec 0.10.4", +] + +[[package]] +name = "icu_segmenter_data" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e52775179941363cc594e49ce99284d13d6948928d8e72c755f55e98caa1eb" + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "1.1.0" @@ -888,6 +1758,12 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "ihex" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "365a784774bb381e8c19edb91190a90d7f2625e057b55de2bc0f6b57bc779ff2" + [[package]] name = "indexmap" version = "2.13.0" @@ -896,6 +1772,8 @@ checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", + "serde", + "serde_core", ] [[package]] @@ -904,9 +1782,31 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" dependencies = [ + "block-padding", "generic-array", ] +[[package]] +name = "io-kit-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617ee6cf8e3f66f3b4ea67a4058564628cde41901316e19f559e14c7c72c5e7b" +dependencies = [ + "core-foundation-sys", + "mach2", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi 0.3.9", + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -929,12 +1829,40 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jep106" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1354c92c91fd5595fd4cc46694b6914749cc90ea437246549c26b6ff0ec6d1" +dependencies = [ + "serde", +] + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.85" @@ -945,57 +1873,169 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "landlock" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fefd6652c57d68aaa32544a4c0e642929725bdc1fd929367cdeb673ab81088" +dependencies = [ + "enumflags2", + "libc", + "thiserror 2.0.18", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "lettre" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e13e10e8818f8b2a60f52cb127041d388b89f3a96a62be9ceaffa22262fef7f" +dependencies = [ + "base64", + "chumsky", + "email-encoding", + "email_address", + "fastrand", + "httpdate", + "idna", + "mime", + "nom 8.0.0", + "percent-encoding", + "quoted_printable", + "rustls", + "socket2", + "tokio", + "url", + "webpki-roots 1.0.6", +] + [[package]] name = "libc" version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "libredox" version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" dependencies = [ - "bitflags", + "bitflags 2.11.0", "libc", ] [[package]] name = "libsqlite3-sys" -version = "0.30.1" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +checksum = "95b4103cffefa72eb8428cb6b47d6627161e51c2739fc5e3b734584157bc642a" dependencies = [ "cc", "pkg-config", "vcpkg", ] +[[package]] +name = "libudev-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c8469b4a23b962c1396b9b451dda50ef5b283e8dd309d69033475fa9b334324" +dependencies = [ + "libc", + "pkg-config", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "linux-raw-sys" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "litemap" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" + [[package]] name = "litemap" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lopdf" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7184fdea2bc3cd272a1acec4030c321a8f9875e877b3f92a53f2f6033fdc289" +dependencies = [ + "aes", + "bitflags 2.11.0", + "cbc", + "ecb", + "encoding_rs", + "flate2", + "getrandom 0.3.4", + "indexmap", + "itoa", + "log", + "md-5", + "nom 8.0.0", + "nom_locate", + "rand 0.9.2", + "rangemap", + "sha2", + "stringprep", + "thiserror 2.0.18", + "ttf-parser", + "weezl", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -1003,10 +2043,47 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] -name = "matchit" -version = "0.7.3" +name = "mach2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "d640282b302c0bb0a2a8e0233ead9035e3bed871f0b7e81fe4a1ec829765db44" +dependencies = [ + "libc", +] + +[[package]] +name = "mail-parser" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f82a3d6522697593ba4c683e0a6ee5a40fee93bc1a525e3cc6eeb3da11fd8897" +dependencies = [ + "hashify", +] + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] [[package]] name = "memchr" @@ -1014,6 +2091,28 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "miette" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7" +dependencies = [ + "cfg-if", + "miette-derive", + "unicode-width 0.1.14", +] + +[[package]] +name = "miette-derive" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "mime" version = "0.3.17" @@ -1036,6 +2135,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + [[package]] name = "mio" version = "1.1.1" @@ -1043,10 +2152,79 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", + "log", "wasi", "windows-sys 0.61.2", ] +[[package]] +name = "mio-serial" +version = "5.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "029e1f407e261176a983a6599c084efd322d9301028055c87174beac71397ba3" +dependencies = [ + "log", + "mio", + "nix 0.29.0", + "serialport", + "winapi", +] + +[[package]] +name = "nix" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +dependencies = [ + "bitflags 1.3.2", + "cfg-if", + "libc", +] + +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "no_std_io2" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a3564ce7035b1e4778d8cb6cacebb5d766b5e8fe5a75b9e441e33fb61a872c6" +dependencies = [ + "memchr", +] + [[package]] name = "nom" version = "7.1.3" @@ -1057,6 +2235,26 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nom" +version = "8.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405" +dependencies = [ + "memchr", +] + +[[package]] +name = "nom_locate" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b577e2d69827c4740cba2b52efaad1c4cc7c73042860b199710b3575c68438d" +dependencies = [ + "bytecount", + "memchr", + "nom 8.0.0", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1066,6 +2264,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-conv" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" + [[package]] name = "num-traits" version = "0.2.19" @@ -1075,6 +2279,63 @@ dependencies = [ "autocfg", ] +[[package]] +name = "nusb" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f861541f15de120eae5982923d073bfc0c1a65466561988c82d6e197734c19e" +dependencies = [ + "atomic-waker", + "core-foundation 0.9.4", + "core-foundation-sys", + "futures-core", + "io-kit-sys", + "libc", + "log", + "once_cell", + "rustix 0.38.44", + "slab", + "windows-sys 0.48.0", +] + +[[package]] +name = "nusb" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0226f4db3ee78f820747cf713767722877f6449d7a0fcfbf2ec3b840969763f" +dependencies = [ + "core-foundation 0.10.1", + "core-foundation-sys", + "futures-core", + "io-kit-sys", + "linux-raw-sys 0.9.4", + "log", + "once_cell", + "rustix 1.1.3", + "slab", + "windows-sys 0.60.2", +] + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + +[[package]] +name = "object" +version = "0.38.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271638cd5fa9cca89c4c304675ca658efc4e64a66c716b7cfe1afb4b9611dbbc" +dependencies = [ + "flate2", + "memchr", + "ruzstd", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -1093,18 +2354,187 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "opentelemetry" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b84bcd6ae87133e903af7ef497404dda70c60d0ea14895fc8a5e6722754fc2a0" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "pin-project-lite", + "thiserror 2.0.18", +] + +[[package]] +name = "opentelemetry-http" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7a6d09a73194e6b66df7c8f1b680f156d916a1a942abf2de06823dd02b7855d" +dependencies = [ + "async-trait", + "bytes", + "http 1.4.0", + "opentelemetry", + "reqwest", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2366db2dca4d2ad033cad11e6ee42844fd727007af5ad04a1730f4cb8163bf" +dependencies = [ + "http 1.4.0", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-proto", + "opentelemetry_sdk", + "prost", + "reqwest", + "thiserror 2.0.18", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7175df06de5eaee9909d4805a3d07e28bb752c34cab57fa9cff549da596b30f" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", + "tonic-prost", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ae4f5991976fd48df6d843de219ca6d31b01daaab2dad5af2badeded372bd" +dependencies = [ + "futures-channel", + "futures-executor", + "futures-util", + "opentelemetry", + "percent-encoding", + "rand 0.9.2", + "thiserror 2.0.18", +] + [[package]] name = "option-ext" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "parse_int" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c464266693329dd5a8715098c7f86e6c5fd5d985018b8318f53d9c6c2b21a31" +dependencies = [ + "num-traits", +] + +[[package]] +name = "pdf-extract" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28ba1758a3d3f361459645780e09570b573fc3c82637449e9963174c813a98" +dependencies = [ + "adobe-cmap-parser", + "cff-parser", + "encoding_rs", + "euclid", + "log", + "lopdf", + "postscript", + "type1-encoding-parser", + "unicode-normalization", +] + [[package]] name = "percent-encoding" version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "phf" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -1123,6 +2553,20 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "polling" +version = "3.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi 0.5.2", + "pin-project-lite", + "rustix 1.1.3", + "windows-sys 0.61.2", +] + [[package]] name = "poly1305" version = "0.8.0" @@ -1134,15 +2578,33 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "pom" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60f6ce597ecdcc9a098e7fddacb1065093a3d66446fa16c675e7e71d1b5c28e6" + +[[package]] +name = "postscript" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78451badbdaebaf17f053fd9152b3ffb33b516104eacb45e7864aaa9c712f306" + [[package]] name = "potential_utf" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" dependencies = [ - "zerovec", + "zerovec 0.11.5", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -1152,6 +2614,75 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "probe-rs" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee27329ac37fa02b194c62a4e3c1aa053739884ea7bcf861249866d3bf7de00" +dependencies = [ + "anyhow", + "async-io", + "bincode", + "bitfield", + "bitvec", + "cobs", + "docsplay", + "dunce", + "espflash", + "flate2", + "futures-lite", + "hidapi", + "ihex", + "itertools", + "jep106", + "nusb 0.1.14", + "object 0.37.3", + "parking_lot", + "probe-rs-target", + "rmp-serde", + "scroll", + "serde", + "serde_yaml", + "serialport", + "thiserror 2.0.18", + "tracing", + "uf2-decode", + "zerocopy", +] + +[[package]] +name = "probe-rs-target" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2239aca5dc62c68ca6d8ff0051fe617cb8363b803380fbc60567e67c82b474df" +dependencies = [ + "base64", + "indexmap", + "jep106", + "serde", + "serde_with", + "url", +] + +[[package]] +name = "proc-macro-crate" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -1161,6 +2692,53 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prometheus" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ca5326d8d0b950a9acd87e6a3f94745394f62e4dae1b1ee22b2bc0c394af43a" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "memchr", + "parking_lot", + "thiserror 2.0.18", +] + +[[package]] +name = "prost" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ea70524a2f82d518bce41317d0fae74151505651af45faf1ffbd6fd33f0568" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "psm" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3852766467df634d74f0b2d7819bf8dc483a0eb2e3b0f50f756f9cfe8b0d18d8" +dependencies = [ + "ar_archive_writer", + "cc", +] + [[package]] name = "quinn" version = "0.11.9" @@ -1225,12 +2803,24 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "quoted_printable" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73" + [[package]] name = "r-efi" version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rand" version = "0.8.5" @@ -1290,6 +2880,21 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rangemap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "973443cf09a9c8656b574a866ab68dfa19f0867d0340648c7d2f6a71b8a8ea68" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags 2.11.0", +] + [[package]] name = "redox_users" version = "0.4.6" @@ -1312,6 +2917,35 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + [[package]] name = "reqwest" version = "0.12.28" @@ -1323,7 +2957,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", + "http 1.4.0", "http-body", "http-body-util", "hyper", @@ -1370,17 +3004,56 @@ dependencies = [ ] [[package]] -name = "rusqlite" -version = "0.32.1" +name = "rmp" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" dependencies = [ - "bitflags", + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + +[[package]] +name = "rppal" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612e1a22e21f08a246657c6433fe52b773ae43d07c9ef88ccfc433cc8683caba" +dependencies = [ + "libc", +] + +[[package]] +name = "rsqlite-vfs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a1f2315036ef6b1fbacd1972e8ee7688030b0a2121edfc2a6550febd41574d" +dependencies = [ + "hashbrown 0.16.1", + "thiserror 2.0.18", +] + +[[package]] +name = "rusqlite" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c93dd1c9683b438c392c492109cb702b8090b2bfc8fed6f6e4eb4523f17af3" +dependencies = [ + "bitflags 2.11.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", "smallvec", + "sqlite-wasm-rs", ] [[package]] @@ -1389,16 +3062,29 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ - "bitflags", + "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.11.0", "windows-sys 0.61.2", ] @@ -1408,6 +3094,8 @@ version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ + "aws-lc-rs", + "log", "once_cell", "ring", "rustls-pki-types", @@ -1416,6 +3104,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" @@ -1432,6 +3132,7 @@ version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -1443,12 +3144,71 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ruzstd" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ff0cc5e135c8870a775d3320910cd9b564ec036b4dc0b8741629020be63f01" +dependencies = [ + "twox-hash", +] + [[package]] name = "ryu" version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "scroll" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1257cd4248b4132760d6524d6dda4e053bc648c9070b960929bf50cfb1e7add" + +[[package]] +name = "security-framework" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d17b898a6d6948c3a8ee4372c17cb384f90d2e6e912ef00895b14fd7ab54ec38" +dependencies = [ + "bitflags 2.11.0", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "321c8673b092a9a42605034a9879d73cb79101ed5fd117bc9a597b89b4e9e61a" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1504,14 +3264,23 @@ dependencies = [ ] [[package]] -name = "serde_spanned" -version = "0.6.9" +name = "serde_plain" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +dependencies = [ + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1524,6 +3293,52 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fa237f2807440d238e0364a218270b98f767a00d3dada77b1c53ae88940e2e7" +dependencies = [ + "base64", + "chrono", + "hex", + "indexmap", + "serde_core", + "serde_json", + "time", +] + +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + +[[package]] +name = "serialport" +version = "4.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acaf3f973e8616d7ceac415f53fc60e190b2a686fbcf8d27d0256c741c5007b" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "core-foundation 0.10.1", + "core-foundation-sys", + "io-kit-sys", + "mach2", + "nix 0.26.4", + "scopeguard", + "unescaper", + "winapi", +] + [[package]] name = "sha1" version = "0.10.6" @@ -1535,6 +3350,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1575,6 +3401,18 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + [[package]] name = "slab" version = "0.4.12" @@ -1597,18 +3435,75 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "sqlite-wasm-rs" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f4206ed3a67690b9c29b77d728f6acc3ce78f16bf846d83c94f76400320181b" +dependencies = [ + "cc", + "js-sys", + "rsqlite-vfs", + "wasm-bindgen", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "stacker" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d74a23609d509411d10e2176dc2a4346e3b4aea2e7b1869f19fdedbc71c013" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys 0.59.0", +] + +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "subtle" version = "2.6.1" @@ -1617,9 +3512,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.115" +version = "2.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e614ed320ac28113fa64972c4262d5dbc89deacdfd00c34a3e4cea073243c12" +checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb" dependencies = [ "proc-macro2", "quote", @@ -1646,6 +3541,12 @@ dependencies = [ "syn", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.25.0" @@ -1653,9 +3554,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.1", "once_cell", - "rustix", + "rustix 1.1.3", "windows-sys 0.61.2", ] @@ -1708,6 +3609,46 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -1715,7 +3656,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" dependencies = [ "displaydoc", - "zerovec", + "zerovec 0.11.5", ] [[package]] @@ -1770,6 +3711,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-serial" +version = "5.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa1d5427f11ba7c5e6384521cfd76f2d64572ff29f3f4f7aa0f496282923fdc8" +dependencies = [ + "cfg-if", + "futures", + "log", + "mio-serial", + "serialport", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.18" @@ -1804,10 +3759,22 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls", - "tungstenite", + "tungstenite 0.24.0", "webpki-roots 0.26.11", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.28.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1823,44 +3790,95 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.23" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" +checksum = "bbe30f93627849fa362d4a602212d41bb237dc2bd0f8ba0b2ce785012e124220" dependencies = [ "indexmap", - "serde", + "serde_core", "serde_spanned", - "toml_datetime", - "toml_write", + "toml_datetime 1.0.0+spec-1.1.0", + "toml_parser", + "toml_writer", "winnow", ] [[package]] -name = "toml_write" -version = "0.1.2" +name = "toml_datetime" +version = "0.7.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_datetime" +version = "1.0.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.23.10+spec-1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" +dependencies = [ + "indexmap", + "toml_datetime 0.7.5+spec-1.1.0", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.8+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0742ff5ff03ea7e67c8ae6c93cac239e0d9784833362da3f9a9c1da8dfefcbdc" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.6+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" + +[[package]] +name = "tonic" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f32a6f80051a4111560201420c7885d0082ba9efe2ab61875c587bb6b18b9a0" +dependencies = [ + "async-trait", + "base64", + "bytes", + "http 1.4.0", + "http-body", + "http-body-util", + "percent-encoding", + "pin-project", + "sync_wrapper", + "tokio-stream", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-prost" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f86539c0089bfd09b1f8c0ab0239d80392af74c21bc9e0f15e1b4aca4c1647f" +dependencies = [ + "bytes", + "prost", + "tonic", +] [[package]] name = "tower" @@ -1883,10 +3901,10 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags", + "bitflags 2.11.0", "bytes", "futures-util", - "http", + "http 1.4.0", "http-body", "http-body-util", "iri-string", @@ -1916,9 +3934,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.36" @@ -1934,9 +3964,13 @@ version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex-automata", "sharded-slab", "thread_local", + "tracing", "tracing-core", ] @@ -1946,6 +3980,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttf-parser" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2df906b07856748fa3f6e0ad0cbaa047052d4a7dd609e231c4f72cee8c36f31" + [[package]] name = "tungstenite" version = "0.24.0" @@ -1955,7 +3995,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http", + "http 1.4.0", "httparse", "log", "rand 0.8.5", @@ -1966,6 +4006,38 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http 1.4.0", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" + +[[package]] +name = "type1-encoding-parser" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3d6cc09e1a99c7e01f2afe4953789311a1c50baebbdac5b477ecf78e2e92a5b" +dependencies = [ + "pom", +] + [[package]] name = "typenum" version = "1.19.0" @@ -1973,16 +4045,70 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] -name = "unicase" -version = "2.8.1" +name = "udev" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +checksum = "50051c6e22be28ee6f217d50014f3bc29e81c20dc66ff7ca0d5c5226e1dcc5a1" +dependencies = [ + "io-lifetimes", + "libc", + "libudev-sys", + "pkg-config", +] + +[[package]] +name = "uf2-decode" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca77d41ab27e3fa45df42043f96c79b80c6d8632eed906b54681d8d47ab00623" + +[[package]] +name = "unescaper" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4064ed685c487dbc25bd3f0e9548f2e34bab9d18cefc700f9ec2dba74ba1138e" +dependencies = [ + "thiserror 2.0.18", +] + +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" + +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "unicode-width" @@ -1990,6 +4116,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "universal-hash" version = "0.5.1" @@ -2000,12 +4132,24 @@ dependencies = [ "subtle", ] +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "url" version = "2.5.8" @@ -2016,6 +4160,7 @@ dependencies = [ "idna", "percent-encoding", "serde", + "serde_derive", ] [[package]] @@ -2038,11 +4183,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.20.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" dependencies = [ - "getrandom 0.3.4", + "getrandom 0.4.1", "js-sys", "wasm-bindgen", ] @@ -2059,6 +4204,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "want" version = "0.3.1" @@ -2083,6 +4234,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.108" @@ -2142,6 +4302,28 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + [[package]] name = "wasm-streams" version = "0.4.2" @@ -2155,6 +4337,18 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.11.0", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.85" @@ -2175,6 +4369,26 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webdriver" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d53921e1bef27512fa358179c9a22428d55778d2c2ae3c5c37a52b82ce6e92" +dependencies = [ + "base64", + "bytes", + "cookie 0.16.2", + "http 0.2.12", + "icu_segmenter", + "log", + "serde", + "serde_derive", + "serde_json", + "thiserror 1.0.69", + "time", + "url", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -2193,6 +4407,34 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "weezl" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" @@ -2497,6 +4739,94 @@ name = "wit-bindgen" version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.11.0", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" [[package]] name = "writeable" @@ -2504,6 +4834,27 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive 0.7.5", + "zerofrom", +] + [[package]] name = "yoke" version = "0.8.1" @@ -2511,10 +4862,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ "stable_deref_trait", - "yoke-derive", + "yoke-derive 0.8.1", "zerofrom", ] +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "yoke-derive" version = "0.8.1" @@ -2534,33 +4897,60 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chacha20poly1305", "chrono", + "chrono-tz", "clap", - "console", + "console 0.15.11", "cron", "dialoguer", "directories", + "fantoccini", + "futures", "futures-util", + "glob", + "hex", + "hmac", "hostname", "http-body-util", + "landlock", + "lettre", + "mail-parser", + "nusb 0.2.1", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry_sdk", + "parking_lot", + "pdf-extract", + "probe-rs", + "prometheus", + "prost", + "rand 0.8.5", + "regex", "reqwest", - "ring", + "rppal", "rusqlite", + "rustls", + "rustls-pki-types", "serde", "serde_json", + "sha2", "shellexpand", "tempfile", "thiserror 2.0.18", "tokio", + "tokio-rustls", + "tokio-serial", "tokio-test", - "tokio-tungstenite", + "tokio-tungstenite 0.24.0", "toml", "tower", "tower-http", "tracing", "tracing-subscriber", "uuid", + "webpki-roots 1.0.6", ] [[package]] @@ -2617,19 +5007,41 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" dependencies = [ "displaydoc", - "yoke", + "yoke 0.8.1", "zerofrom", ] +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke 0.7.5", + "zerofrom", + "zerovec-derive 0.10.3", +] + [[package]] name = "zerovec" version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ - "yoke", + "yoke 0.8.1", "zerofrom", - "zerovec-derive", + "zerovec-derive 0.11.2", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 595ab6c..81a22b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,15 @@ +[workspace] +members = ["."] +resolver = "2" + [package] name = "zeroclaw" version = "0.1.0" edition = "2021" authors = ["theonlyhennygod"] -license = "MIT" +license = "Apache-2.0" description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant." -repository = "https://github.com/theonlyhennygod/zeroclaw" +repository = "https://github.com/zeroclaw-labs/zeroclaw" readme = "README.md" keywords = ["ai", "agent", "cli", "assistant", "chatbot"] categories = ["command-line-utilities", "api-bindings"] @@ -26,12 +30,21 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] } # Config directories = "5.0" -toml = "0.8" +toml = "1.0" shellexpand = "3.1" # Logging - minimal tracing = { version = "0.1", default-features = false } -tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] } + +# Observability - Prometheus metrics +prometheus = { version = "0.14", default-features = false } + +# Base64 encoding (screenshots, image data) +base64 = "0.22" + +# Optional Rust-native browser automation backend +fantoccini = { version = "0.22.0", optional = true, default-features = false, features = ["rustls-tls"] } # Error handling anyhow = "1.0" @@ -43,38 +56,109 @@ uuid = { version = "1.11", default-features = false, features = ["v4", "std"] } # Authenticated encryption (AEAD) for secret store chacha20poly1305 = "0.10" +# HMAC for webhook signature verification +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" + +# CSPRNG for secure token generation +rand = "0.8" + +# Fast mutexes that don't poison on panic +parking_lot = "0.12" + # Async traits async-trait = "0.1" # HMAC-SHA256 (Zhipu/GLM JWT auth) ring = "0.17" +# Protobuf encode/decode (Feishu WS long-connection frame codec) +prost = { version = "0.14", default-features = false } + # Memory / persistence -rusqlite = { version = "0.32", features = ["bundled"] } -chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } +rusqlite = { version = "0.38", features = ["bundled"] } +chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } +chrono-tz = "0.10" cron = "0.12" # Interactive CLI prompts -dialoguer = { version = "0.11", features = ["fuzzy-select"] } +dialoguer = { version = "0.12", features = ["fuzzy-select"] } console = "0.15" +# Hardware discovery (device path globbing) +glob = "0.3" + # Discord WebSocket gateway tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] } futures-util = { version = "0.3", default-features = false, features = ["sink"] } +futures = "0.3" +regex = "1.10" hostname = "0.4.2" +lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] } +mail-parser = "0.11.2" +rustls = "0.23" +rustls-pki-types = "1.14.0" +tokio-rustls = "0.26.4" +webpki-roots = "1.0.6" # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance -axum = { version = "0.7", default-features = false, features = ["http1", "json", "tokio", "query"] } +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } tower = { version = "0.5", default-features = false } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } http-body-util = "0.1" +# OpenTelemetry — OTLP trace + metrics export +opentelemetry = { version = "0.31", default-features = false, features = ["trace", "metrics"] } +opentelemetry_sdk = { version = "0.31", default-features = false, features = ["trace", "metrics"] } +opentelemetry-otlp = { version = "0.31", default-features = false, features = ["trace", "metrics", "http-proto", "reqwest-client", "reqwest-rustls-webpki-roots"] } + +# USB device enumeration (hardware discovery) +nusb = { version = "0.2", default-features = false, optional = true } + +# Serial port for peripheral communication (STM32, etc.) +tokio-serial = { version = "5", default-features = false, optional = true } + +# probe-rs for STM32/Nucleo memory read (Phase B) +probe-rs = { version = "0.30", optional = true } + +# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) +pdf-extract = { version = "0.10", optional = true } + +# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS +[target.'cfg(target_os = "linux")'.dependencies] +rppal = { version = "0.14", optional = true } +landlock = { version = "0.4", optional = true } + +[features] +default = ["hardware"] +hardware = ["nusb", "tokio-serial"] +peripheral-rpi = ["rppal"] +# Browser backend feature alias used by cfg(feature = "browser-native") +browser-native = ["dep:fantoccini"] +# Backward-compatible alias for older invocations +fantoccini = ["browser-native"] +# Sandbox feature aliases used by cfg(feature = "sandbox-*") +sandbox-landlock = ["dep:landlock"] +sandbox-bubblewrap = [] +# Backward-compatible alias for older invocations +landlock = ["sandbox-landlock"] +# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional) +probe = ["dep:probe-rs"] +# rag-pdf = PDF ingestion for datasheet RAG +rag-pdf = ["dep:pdf-extract"] [profile.release] opt-level = "z" # Optimize for size -lto = true # Link-time optimization -codegen-units = 1 # Better optimization -strip = true # Remove debug symbols -panic = "abort" # Reduce binary size +lto = "thin" # Lower memory use during release builds +codegen-units = 1 # Serialized codegen for low-memory devices (e.g., Raspberry Pi 3 with 1GB RAM) + # Higher values (e.g., 8) compile faster but require more RAM during compilation +strip = true # Remove debug symbols +panic = "abort" # Reduce binary size + +[profile.release-fast] +inherits = "release" +codegen-units = 8 # Parallel codegen for faster builds on powerful machines (16GB+ RAM recommended) + # Use: cargo build --profile release-fast [profile.dist] inherits = "release" diff --git a/Dockerfile b/Dockerfile index 0975ee8..693e4de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,41 +1,115 @@ +# syntax=docker/dockerfile:1.7 + # ── Stage 1: Build ──────────────────────────────────────────── -FROM rust:1.83-slim AS builder +FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e AS builder WORKDIR /app + +# Install build dependencies +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get update && apt-get install -y \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# 1. Copy manifests to cache dependencies COPY Cargo.toml Cargo.lock ./ -COPY src/ src/ +# Create dummy main.rs to build dependencies +RUN mkdir src && echo "fn main() {}" > src/main.rs +RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \ + --mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \ + --mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \ + cargo build --release --locked +RUN rm -rf src -RUN cargo build --release --locked && \ - strip target/release/zeroclaw +# 2. Copy source code +COPY . . +RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \ + --mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \ + --mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \ + cargo build --release --locked && \ + cp target/release/zeroclaw /app/zeroclaw && \ + strip /app/zeroclaw -# ── Stage 2: Runtime (distroless nonroot — no shell, no OS, tiny, UID 65534) ── -FROM gcr.io/distroless/cc-debian12:nonroot +# ── Stage 2: Permissions & Config Prep ─────────────────────── +FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions +# Create directory structure (simplified workspace path) +RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace -COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw +# Create minimal config for PRODUCTION (allows binding to public interfaces) +# NOTE: Provider configuration must be done via environment variables at runtime +RUN cat > /zeroclaw-data/.zeroclaw/config.toml <ZeroClaw 🦀

- Zero overhead. Zero compromise. 100% Rust. 100% Agnostic. + Zero overhead. Zero compromise. 100% Rust. 100% Agnostic.
+ ⚡️ Runs on $10 hardware with <5MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini!

License: MIT + Contributors

Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. ``` -~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything +~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything ``` +### ✨ Features + +- 🏎️ **Ultra-Lightweight:** <5MB Memory footprint — 99% smaller than OpenClaw core. +- 💰 **Minimal Cost:** Efficient enough to run on $10 Hardware — 98% cheaper than a Mac mini. +- ⚡ **Lightning Fast:** 400X Faster startup time, boot in <10ms (under 1s even on 0.6GHz cores). +- 🌍 **True Portability:** Single self-contained binary across ARM, x86, and RISC-V. + ### Why teams pick ZeroClaw - **Lean by default:** small Rust binary, fast startup, low memory footprint. @@ -27,17 +36,21 @@ Fast, small, and fully autonomous AI assistant infrastructure — deploy anywher ## Benchmark Snapshot (ZeroClaw vs OpenClaw) -Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. +Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge hardware. -| Metric | ZeroClaw (Rust release binary) | OpenClaw (Node + built `dist`) | -|---|---:|---:| -| Build output size | `target/release/zeroclaw`: **3.4 MB** | `dist/`: **28 MB** | -| `--help` startup (cold/warm) | **0.38s / ~0.00s** | **3.31s / ~1.11s** | -| `status` command runtime (best of 3) | **~0.00s** | **5.98s** | -| `--help` max RSS observed | **~7.3 MB** | **~394 MB** | -| `status` max RSS observed | **~7.8 MB** | **~1.52 GB** | +| | OpenClaw | NanoBot | PicoClaw | ZeroClaw 🦀 | +|---|---|---|---|---| +| **Language** | TypeScript | Python | Go | **Rust** | +| **RAM** | > 1GB | > 100MB | < 10MB | **< 5MB** | +| **Startup (0.8GHz core)** | > 500s | > 30s | < 1s | **< 10ms** | +| **Binary Size** | ~28MB (dist) | N/A (Scripts) | ~8MB | **3.4 MB** | +| **Cost** | Mac Mini $599 | Linux SBC ~$50 | Linux Board $10 | **Any hardware $10** | -> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results were measured after `pnpm install` + `pnpm build`. +> Notes: ZeroClaw results measured with `/usr/bin/time -l` on release builds. OpenClaw requires Node.js runtime (~390MB overhead). PicoClaw and ZeroClaw are static binaries. + +

+ ZeroClaw vs OpenClaw Comparison +

Reproduce ZeroClaw numbers locally: @@ -49,13 +62,78 @@ ls -lh target/release/zeroclaw /usr/bin/time -l target/release/zeroclaw status ``` +## Prerequisites + +
+Windows + +#### Required + +1. **Visual Studio Build Tools** (provides the MSVC linker and Windows SDK): + ```powershell + winget install Microsoft.VisualStudio.2022.BuildTools + ``` + During installation (or via the Visual Studio Installer), select the **"Desktop development with C++"** workload. + +2. **Rust toolchain:** + ```powershell + winget install Rustlang.Rustup + ``` + After installation, open a new terminal and run `rustup default stable` to ensure the stable toolchain is active. + +3. **Verify** both are working: + ```powershell + rustc --version + cargo --version + ``` + +#### Optional + +- **Docker Desktop** — required only if using the [Docker sandboxed runtime](#runtime-support-current) (`runtime.kind = "docker"`). Install via `winget install Docker.DockerDesktop`. + +
+ +
+Linux / macOS + +#### Required + +1. **Build essentials:** + - **Linux (Debian/Ubuntu):** `sudo apt install build-essential pkg-config` + - **Linux (Fedora/RHEL):** `sudo dnf groupinstall "Development Tools" && sudo dnf install pkg-config` + - **macOS:** Install Xcode Command Line Tools: `xcode-select --install` + +2. **Rust toolchain:** + ```bash + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + ``` + See [rustup.rs](https://rustup.rs) for details. + +3. **Verify** both are working: + ```bash + rustc --version + cargo --version + ``` + +#### Optional + +- **Docker** — required only if using the [Docker sandboxed runtime](#runtime-support-current) (`runtime.kind = "docker"`). Install via your package manager or [docker.com](https://docs.docker.com/engine/install/). + +> **Note:** The default `cargo build --release` uses `codegen-units=1` for compatibility with low-memory devices (e.g., Raspberry Pi 3 with 1GB RAM). For faster builds on powerful machines, use `cargo build --profile release-fast`. + +
+ + ## Quick Start ```bash -git clone https://github.com/theonlyhennygod/zeroclaw.git +git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw -cargo build --release -cargo install --path . --force +cargo build --release --locked +cargo install --path . --force --locked + +# Ensure ~/.cargo/bin is in your PATH +export PATH="$HOME/.cargo/bin:$PATH" # Quick setup (no prompts) zeroclaw onboard --api-key sk-... --provider openrouter @@ -88,6 +166,9 @@ zeroclaw doctor # Check channel health zeroclaw channel doctor +# Bind a Telegram identity into allowlist +zeroclaw channel bind-telegram 123456789 + # Get integration setup details zeroclaw integrations info Telegram @@ -112,12 +193,12 @@ Every subsystem is a **trait** — swap implementations with a config change, ze | Subsystem | Trait | Ships with | Extend | |-----------|-------|------------|--------| -| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | -| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API | -| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend | -| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability | +| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | +| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, WhatsApp, Webhook | Any messaging API | +| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend | +| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability | | **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | -| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM (planned; unsupported kinds fail fast) | +| **Runtime** | `RuntimeAdapter` | Native, Docker (sandboxed) | WASM (planned; unsupported kinds fail fast) | | **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — | | **Identity** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Any identity format | | **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary | @@ -127,8 +208,8 @@ Every subsystem is a **trait** — swap implementations with a config change, ze ### Runtime support (current) -- ✅ Supported today: `runtime.kind = "native"` -- 🚧 Planned, not implemented yet: Docker / WASM / edge runtimes +- ✅ Supported today: `runtime.kind = "native"` or `runtime.kind = "docker"` +- 🚧 Planned, not implemented yet: WASM / edge runtimes When an unsupported `runtime.kind` is configured, ZeroClaw now exits with a clear error instead of silently falling back to native. @@ -150,11 +231,21 @@ The agent automatically recalls, saves, and manages memory via tools. ```toml [memory] -backend = "sqlite" # "sqlite", "markdown", "none" +backend = "sqlite" # "sqlite", "lucid", "markdown", "none" auto_save = true embedding_provider = "openai" vector_weight = 0.7 keyword_weight = 0.3 + +# backend = "none" uses an explicit no-op memory backend (no persistence) + +# Optional for backend = "lucid" +# ZEROCLAW_LUCID_CMD=/usr/local/bin/lucid # default: lucid +# ZEROCLAW_LUCID_BUDGET=200 # default: 200 +# ZEROCLAW_LUCID_LOCAL_HIT_THRESHOLD=3 # local hit count to skip external recall +# ZEROCLAW_LUCID_RECALL_TIMEOUT_MS=120 # low-latency budget for lucid context recall +# ZEROCLAW_LUCID_STORE_TIMEOUT_MS=800 # async sync timeout for lucid store +# ZEROCLAW_LUCID_FAILURE_COOLDOWN_MS=15000 # cooldown after lucid failure to avoid repeated slow attempts ``` ## Security @@ -172,7 +263,7 @@ ZeroClaw enforces security at **every layer** — not just the sandbox. It passe > **Run your own nmap:** `nmap -p 1-65535 ` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel. -### Channel allowlists (Telegram / Discord / Slack) +### Channel allowlists (Telegram / Discord / Slack / Mattermost) Inbound sender policy is now consistent: @@ -187,8 +278,22 @@ Recommended low-friction setup (secure + fast): - **Telegram:** allowlist your own `@username` (without `@`) and/or your numeric Telegram user ID. - **Discord:** allowlist your own Discord user ID. - **Slack:** allowlist your own Slack member ID (usually starts with `U`). +- **Mattermost:** uses standard API v4. Allowlists use Mattermost user IDs. - Use `"*"` only for temporary open testing. +Telegram operator-approval flow: + +1. Keep `[channels_config.telegram].allowed_users = []` for deny-by-default startup. +2. Unauthorized users receive a hint with a copyable operator command: + `zeroclaw channel bind-telegram `. +3. Operator runs that command locally, then user retries sending a message. + +If you need a one-shot manual approval, run: + +```bash +zeroclaw channel bind-telegram 123456789 +``` + If you're not sure which identity to use: 1. Start channels and send one message to your bot. @@ -202,6 +307,21 @@ rerun channel setup only: zeroclaw onboard --channels-only ``` +### Telegram media replies + +Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames), +which avoids `Bad Request: chat not found` failures. + +For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers: + +- `[IMAGE:]` +- `[DOCUMENT:]` +- `[VIDEO:]` +- `[AUDIO:]` +- `[VOICE:]` + +Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs. + ### WhatsApp Business Cloud API Setup WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): @@ -250,12 +370,14 @@ default_model = "anthropic/claude-sonnet-4-20250514" default_temperature = 0.7 [memory] -backend = "sqlite" # "sqlite", "markdown", "none" +backend = "sqlite" # "sqlite", "lucid", "markdown", "none" auto_save = true embedding_provider = "openai" # "openai", "noop" vector_weight = 0.7 keyword_weight = 0.3 +# backend = "none" disables persistent memory via no-op backend + [gateway] require_pairing = true # require pairing code on first connect allow_public_bind = false # refuse 0.0.0.0 without tunnel @@ -267,7 +389,16 @@ allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"] forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"] [runtime] -kind = "native" # only supported value right now; unsupported kinds fail fast +kind = "native" # "native" or "docker" + +[runtime.docker] +image = "alpine:3.20" # container image for shell execution +network = "none" # docker network mode ("none", "bridge", etc.) +memory_limit_mb = 512 # optional memory limit in MB +cpu_limit = 1.0 # optional CPU limit +read_only_rootfs = true # mount root filesystem as read-only +mount_workspace = true # mount workspace into /workspace +allowed_workspace_roots = [] # optional allowlist for workspace mount validation [heartbeat] enabled = false @@ -280,11 +411,40 @@ provider = "none" # "none", "cloudflare", "tailscale", "ngrok", "c encrypt = true # API keys encrypted with local key file [browser] -enabled = false # opt-in browser_open tool -allowed_domains = ["docs.rs"] # required when browser is enabled +enabled = false # opt-in browser_open + browser tools +allowed_domains = ["docs.rs"] # required when browser is enabled +backend = "agent_browser" # "agent_browser" (default), "rust_native", "computer_use", "auto" +native_headless = true # applies when backend uses rust-native +native_webdriver_url = "http://127.0.0.1:9515" # WebDriver endpoint (chromedriver/selenium) +# native_chrome_path = "/usr/bin/chromium" # optional explicit browser binary for driver + +[browser.computer_use] +endpoint = "http://127.0.0.1:8787/v1/actions" # computer-use sidecar HTTP endpoint +timeout_ms = 15000 # per-action timeout +allow_remote_endpoint = false # secure default: only private/localhost endpoint +window_allowlist = [] # optional window title/process allowlist hints +# api_key = "..." # optional bearer token for sidecar +# max_coordinate_x = 3840 # optional coordinate guardrail +# max_coordinate_y = 2160 # optional coordinate guardrail + +# Rust-native backend build flag: +# cargo build --release --features browser-native +# Ensure a WebDriver server is running, e.g. chromedriver --port=9515 + +# Computer-use sidecar contract (MVP) +# POST browser.computer_use.endpoint +# Request: { +# "action": "mouse_click", +# "params": {"x": 640, "y": 360, "button": "left"}, +# "policy": {"allowed_domains": [...], "window_allowlist": [...], "max_coordinate_x": 3840, "max_coordinate_y": 2160}, +# "metadata": {"session_name": "...", "source": "zeroclaw.browser", "version": "..."} +# } +# Response: {"success": true, "data": {...}} or {"success": false, "error": "..."} [composio] enabled = false # opt-in: 1000+ OAuth apps via composio.dev +# api_key = "cmp_..." # optional: stored encrypted when [secrets].encrypt = true +entity_id = "default" # default user_id for Composio tool calls [identity] format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON) @@ -292,6 +452,57 @@ format = "openclaw" # "openclaw" (default, markdown files) or "aieos # aieos_inline = '{"identity":{"names":{"first":"Nova"}}}' # inline AIEOS JSON ``` +### Ollama Local and Remote Endpoints + +ZeroClaw uses one provider key (`ollama`) for both local and remote Ollama deployments: + +- Local Ollama: keep `api_url` unset, run `ollama serve`, and use models like `llama3.2`. +- Remote Ollama endpoint (including Ollama Cloud): set `api_url` to the remote endpoint and set `api_key` (or `OLLAMA_API_KEY`) when required. +- Optional `:cloud` suffix: model IDs like `qwen3:cloud` are normalized to `qwen3` before the request. + +Example remote configuration: + +```toml +default_provider = "ollama" +default_model = "qwen3:cloud" +api_url = "https://ollama.com" +api_key = "ollama_api_key_here" +``` + +## Python Companion Package (`zeroclaw-tools`) + +For LLM providers with inconsistent native tool calling (e.g., GLM-5/Zhipu), ZeroClaw ships a Python companion package with **LangGraph-based tool calling** for guaranteed consistency: + +```bash +pip install zeroclaw-tools +``` + +```python +from zeroclaw_tools import create_agent, shell, file_read +from langchain_core.messages import HumanMessage + +# Works with any OpenAI-compatible provider +agent = create_agent( + tools=[shell, file_read], + model="glm-5", + api_key="your-key", + base_url="https://api.z.ai/api/coding/paas/v4" +) + +result = await agent.ainvoke({ + "messages": [HumanMessage(content="List files in /tmp")] +}) +print(result["messages"][-1].content) +``` + +**Why use it:** +- **Consistent tool calling** across all providers (even those with poor native support) +- **Automatic tool loop** — keeps calling tools until the task is complete +- **Easy extensibility** — add custom tools with `@tool` decorator +- **Discord bot integration** included (Telegram planned) + +See [`python/README.md`](python/README.md) for full documentation. + ## Identity System (AIEOS Support) ZeroClaw supports **identity-agnostic** AI personas through two formats: @@ -386,13 +597,15 @@ See [aieos.org](https://aieos.org) for the full schema and live examples. | `doctor` | Diagnose daemon/scheduler/channel freshness | | `status` | Show full system status | | `channel doctor` | Run health checks for configured channels | +| `channel bind-telegram ` | Add one Telegram username/user ID to allowlist | | `integrations info ` | Show setup/status details for one integration | ## Development ```bash cargo build # Dev build -cargo build --release # Release build (~3.4MB) +cargo build --release # Release build (codegen-units=1, works on all devices including Raspberry Pi) +cargo build --profile release-fast # Faster build (codegen-units=8, requires 16GB+ RAM) cargo test # 1,017 tests cargo clippy # Lint (0 warnings) cargo fmt # Format @@ -409,19 +622,53 @@ A git hook runs `cargo fmt --check`, `cargo clippy -- -D warnings`, and `cargo t git config core.hooksPath .githooks ``` +### Build troubleshooting (Linux OpenSSL errors) + +If you see an `openssl-sys` build error, sync dependencies and rebuild with the repository lockfile: + +```bash +git pull +cargo build --release --locked +cargo install --path . --force --locked +``` + +ZeroClaw is configured to use `rustls` for HTTP/TLS dependencies; `--locked` keeps the transitive graph deterministic on fresh environments. + To skip the hook when you need a quick push during development: ```bash git push --no-verify ``` +## Collaboration & Docs + +For high-throughput collaboration and consistent reviews: + +- Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md) +- PR workflow policy: [docs/pr-workflow.md](docs/pr-workflow.md) +- Reviewer playbook (triage + deep review): [docs/reviewer-playbook.md](docs/reviewer-playbook.md) +- CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md) +- Security disclosure policy: [SECURITY.md](SECURITY.md) + +### 🙏 Special Thanks + +A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work: + +- **Harvard University** — for fostering intellectual curiosity and pushing the boundaries of what's possible. +- **MIT** — for championing open knowledge, open source, and the belief that technology should be accessible to everyone. +- **Sundai Club** — for the community, the energy, and the relentless drive to build things that matter. +- **The World & Beyond** 🌍✨ — to every contributor, dreamer, and builder out there making open source a force for good. This is for you. + +We're building in the open because the best ideas come from everywhere. If you're reading this, you're part of it. Welcome. 🦀❤️ + ## License -MIT — see [LICENSE](LICENSE) +MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR: +- CI workflow guide: [docs/ci-map.md](docs/ci-map.md) - New `Provider` → `src/providers/` - New `Channel` → `src/channels/` - New `Observer` → `src/observability/` @@ -433,3 +680,11 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR: --- **ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀 + +## Star History + +

+ + Star History Chart + +

diff --git a/RUN_TESTS.md b/RUN_TESTS.md new file mode 100644 index 0000000..eddc578 --- /dev/null +++ b/RUN_TESTS.md @@ -0,0 +1,303 @@ +# 🧪 Test Execution Guide + +## Quick Reference + +```bash +# Full automated test suite (~2 min) +./test_telegram_integration.sh + +# Quick smoke test (~10 sec) +./quick_test.sh + +# Just compile and unit test (~30 sec) +cargo test telegram --lib +``` + +## 📝 What Was Created For You + +### 1. **test_telegram_integration.sh** (Main Test Suite) + - **20+ automated tests** covering all fixes + - **6 test phases**: Code quality, build, config, health, features, manual + - **Colored output** with pass/fail indicators + - **Detailed summary** at the end + + ```bash + ./test_telegram_integration.sh + ``` + +### 2. **quick_test.sh** (Fast Validation) + - **4 essential tests** for quick feedback + - **<10 second** execution time + - Perfect for **pre-commit** checks + + ```bash + ./quick_test.sh + ``` + +### 3. **generate_test_messages.py** (Test Helper) + - Generates test messages of various lengths + - Tests message splitting functionality + - 8 different message types + + ```bash + # Generate a long message (>4096 chars) + python3 test_helpers/generate_test_messages.py long + + # Show all message types + python3 test_helpers/generate_test_messages.py all + ``` + +### 4. **TESTING_TELEGRAM.md** (Complete Guide) + - Comprehensive testing documentation + - Troubleshooting guide + - Performance benchmarks + - CI/CD integration examples + +## 🚀 Step-by-Step: First Run + +### Step 1: Run Automated Tests + +```bash +cd /Users/abdzsam/zeroclaw + +# Make scripts executable (already done) +chmod +x test_telegram_integration.sh quick_test.sh + +# Run the full test suite +./test_telegram_integration.sh +``` + +**Expected output:** +``` +⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ + +███████╗███████╗██████╗ ██████╗ ██████╗██╗ █████╗ ██╗ ██╗ +... + +🧪 TELEGRAM INTEGRATION TEST SUITE 🧪 + +Phase 1: Code Quality Tests +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Test 1: Compiling test suite +✓ PASS: Test suite compiles successfully + +Test 2: Running Telegram unit tests +✓ PASS: All Telegram unit tests passed (24 tests) +... + +Test Summary +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Total Tests: 20 +Passed: 20 +Failed: 0 +Warnings: 0 + +Pass Rate: 100% + +✓ ALL AUTOMATED TESTS PASSED! 🎉 +``` + +### Step 2: Configure Telegram (if not done) + +```bash +# Interactive setup +zeroclaw onboard --interactive + +# Or channels-only setup +zeroclaw onboard --channels-only +``` + +When prompted: +1. Select **Telegram** channel +2. Enter your **bot token** from @BotFather +3. Enter your **Telegram user ID** or username + +### Step 3: Verify Health + +```bash +zeroclaw channel doctor +``` + +**Expected output:** +``` +🩺 ZeroClaw Channel Doctor + + ✅ Telegram healthy + +Summary: 1 healthy, 0 unhealthy, 0 timed out +``` + +### Step 4: Manual Testing + +#### Test 1: Basic Message + +```bash +# Terminal 1: Start the channel +zeroclaw channel start +``` + +**In Telegram:** +- Find your bot +- Send: `Hello bot!` +- **Verify**: Bot responds within 3 seconds + +#### Test 2: Long Message (Split Test) + +```bash +# Generate a long message +python3 test_helpers/generate_test_messages.py long +``` + +- **Copy the output** +- **Paste into Telegram** to your bot +- **Verify**: + - Message is split into 2+ chunks + - First chunk ends with `(continues...)` + - Middle chunks have `(continued)` and `(continues...)` + - Last chunk starts with `(continued)` + - All chunks arrive in order + +#### Test 3: Word Boundary Splitting + +```bash +python3 test_helpers/generate_test_messages.py word +``` + +- Send to bot +- **Verify**: Splits at word boundaries (not mid-word) + +## 🎯 Test Results Checklist + +After running all tests, verify: + +### Automated Tests +- [ ] ✅ All 20 automated tests passed +- [ ] ✅ Build completed successfully +- [ ] ✅ Binary size <10MB +- [ ] ✅ Health check completes in <5s +- [ ] ✅ No clippy warnings + +### Manual Tests +- [ ] ✅ Bot responds to basic messages +- [ ] ✅ Long messages split correctly +- [ ] ✅ Continuation markers appear +- [ ] ✅ Word boundaries respected +- [ ] ✅ Allowlist blocks unauthorized users +- [ ] ✅ No errors in logs + +### Performance +- [ ] ✅ Response time <3 seconds +- [ ] ✅ Memory usage <10MB +- [ ] ✅ No message loss +- [ ] ✅ Rate limiting works (100ms delays) + +## 🐛 Troubleshooting + +### Issue: Tests fail to compile + +```bash +# Clean build +cargo clean +cargo build --release + +# Update dependencies +cargo update +``` + +### Issue: "Bot token not configured" + +```bash +# Check config +cat ~/.zeroclaw/config.toml | grep -A 5 telegram + +# Reconfigure +zeroclaw onboard --channels-only +``` + +### Issue: Health check fails + +```bash +# Test bot token directly +curl "https://api.telegram.org/bot/getMe" + +# Should return: {"ok":true,"result":{...}} +``` + +### Issue: Bot doesn't respond + +```bash +# Enable debug logging +RUST_LOG=debug zeroclaw channel start + +# Look for: +# - "Telegram channel listening for messages..." +# - "ignoring message from unauthorized user" (if allowlist issue) +# - Any error messages +``` + +## 📊 Performance Benchmarks + +After all fixes, you should see: + +| Metric | Target | Command | +|--------|--------|---------| +| Unit test pass | 24/24 | `cargo test telegram --lib` | +| Build time | <30s | `time cargo build --release` | +| Binary size | ~3-4MB | `ls -lh target/release/zeroclaw` | +| Health check | <5s | `time zeroclaw channel doctor` | +| First response | <3s | Manual test in Telegram | +| Message split | <50ms | Check debug logs | +| Memory usage | <10MB | `ps aux \| grep zeroclaw` | + +## 🔄 CI/CD Integration + +Add to your workflow: + +```bash +# Pre-commit hook +#!/bin/bash +./quick_test.sh + +# CI pipeline +./test_telegram_integration.sh +``` + +## 📚 Next Steps + +1. **Run the tests:** + ```bash + ./test_telegram_integration.sh + ``` + +2. **Fix any failures** using the troubleshooting guide + +3. **Complete manual tests** using the checklist + +4. **Deploy to production** when all tests pass + +5. **Monitor logs** for any issues: + ```bash + zeroclaw daemon + # or + RUST_LOG=info zeroclaw channel start + ``` + +## 🎉 Success! + +If all tests pass: +- ✅ Message splitting works (4096 char limit) +- ✅ Health check has 5s timeout +- ✅ Empty chat_id is handled safely +- ✅ All 24 unit tests pass +- ✅ Code is production-ready + +**Your Telegram integration is ready to go!** 🚀 + +--- + +## 📞 Support + +- Issues: https://github.com/theonlyhennygod/zeroclaw/issues +- Docs: `./TESTING_TELEGRAM.md` +- Help: `zeroclaw --help` diff --git a/TESTING_TELEGRAM.md b/TESTING_TELEGRAM.md new file mode 100644 index 0000000..128ff76 --- /dev/null +++ b/TESTING_TELEGRAM.md @@ -0,0 +1,337 @@ +# Telegram Integration Testing Guide + +This guide covers testing the Telegram channel integration for ZeroClaw. + +## 🚀 Quick Start + +### Automated Tests + +```bash +# Full test suite (20+ tests, ~2 minutes) +./test_telegram_integration.sh + +# Quick smoke test (~10 seconds) +./quick_test.sh + +# Just unit tests +cargo test telegram --lib +``` + +## 📋 Test Coverage + +### Automated Tests (20 tests) + +The `test_telegram_integration.sh` script runs: + +**Phase 1: Code Quality (5 tests)** + +- ✅ Test compilation +- ✅ Unit tests (24 tests) +- ✅ Message splitting tests (8 tests) +- ✅ Clippy linting +- ✅ Code formatting + +**Phase 2: Build Tests (3 tests)** + +- ✅ Debug build +- ✅ Release build +- ✅ Binary size verification (<10MB) + +**Phase 3: Configuration Tests (4 tests)** + +- ✅ Config file exists +- ✅ Telegram section configured +- ✅ Bot token set +- ✅ User allowlist configured + +**Phase 4: Health Check Tests (2 tests)** + +- ✅ Health check timeout (<5s) +- ✅ Telegram API connectivity + +**Phase 5: Feature Validation (6 tests)** + +- ✅ Message splitting function +- ✅ Message length constant (4096) +- ✅ Timeout implementation +- ✅ chat_id validation +- ✅ Duration import +- ✅ Continuation markers + +### Manual Tests (6 tests) + +After running automated tests, perform these manual checks: + +1. **Basic messaging** + + ```bash + zeroclaw channel start + ``` + + - Send "Hello bot!" in Telegram + - Verify response within 3 seconds + +2. **Long message splitting** + + ```bash + # Generate 5000+ char message + python3 -c 'print("test " * 1000)' + ``` + + - Paste into Telegram + - Verify: Message split into chunks + - Verify: Markers show `(continues...)` and `(continued)` + - Verify: All chunks arrive in order + +3. **Unauthorized user blocking** + + ```toml + # Edit ~/.zeroclaw/config.toml + allowed_users = ["999999999"] + ``` + + - Send message to bot + - Verify: Warning in logs + - Verify: Message ignored + - Restore correct user ID + +4. **Rate limiting** + - Send 10 messages rapidly + - Verify: All processed + - Verify: No "Too Many Requests" errors + - Verify: Responses have delays + +5. **Error logging** + + ```bash + RUST_LOG=debug zeroclaw channel start + ``` + + - Check for unexpected errors + - Verify proper error handling + +6. **Health check timeout** + + ```bash + time zeroclaw channel doctor + ``` + + - Verify: Completes in <5 seconds + +## 🔍 Test Results Interpretation + +### Success Criteria + +- All 20 automated tests pass ✅ +- Health check completes in <5s ✅ +- Binary size <10MB ✅ +- No clippy warnings ✅ +- All manual tests pass ✅ + +### Common Issues + +**Issue: Health check times out** + +``` +Solution: Check bot token is valid + curl "https://api.telegram.org/bot/getMe" +``` + +**Issue: Bot doesn't respond** + +``` +Solution: Check user allowlist + 1. Send message to bot + 2. Check logs for user_id + 3. Update config: allowed_users = ["YOUR_ID"] + 4. Run: zeroclaw onboard --channels-only +``` + +**Issue: Message splitting not working** + +``` +Solution: Verify code changes + grep -n "split_message_for_telegram" src/channels/telegram.rs + grep -n "TELEGRAM_MAX_MESSAGE_LENGTH" src/channels/telegram.rs +``` + +## 🧪 Test Scenarios + +### Scenario 1: First-Time Setup + +```bash +# 1. Run automated tests +./test_telegram_integration.sh + +# 2. Configure Telegram +zeroclaw onboard --interactive +# Select Telegram channel +# Enter bot token (from @BotFather) +# Enter your user ID + +# 3. Verify health +zeroclaw channel doctor + +# 4. Start channel +zeroclaw channel start + +# 5. Send test message in Telegram +``` + +### Scenario 2: After Code Changes + +```bash +# 1. Quick validation +./quick_test.sh + +# 2. Full test suite +./test_telegram_integration.sh + +# 3. Manual smoke test +zeroclaw channel start +# Send message in Telegram +``` + +### Scenario 3: Production Deployment + +```bash +# 1. Full test suite +./test_telegram_integration.sh + +# 2. Load test (optional) +# Send 100 messages rapidly +for i in {1..100}; do + echo "Test message $i" | \ + curl -X POST "https://api.telegram.org/bot/sendMessage" \ + -d "chat_id=" \ + -d "text=Message $i" +done + +# 3. Monitor logs +RUST_LOG=info zeroclaw daemon + +# 4. Check metrics +zeroclaw status +``` + +## 📊 Performance Benchmarks + +Expected values after all fixes: + +| Metric | Expected | How to Measure | +| ---------------------- | ---------- | -------------------------------- | +| Health check time | <5s | `time zeroclaw channel doctor` | +| First response time | <3s | Time from sending to receiving | +| Message split overhead | <50ms | Check logs for timing | +| Memory usage | <10MB | `ps aux \| grep zeroclaw` | +| Binary size | ~3-4MB | `ls -lh target/release/zeroclaw` | +| Unit test coverage | 24/24 pass | `cargo test telegram --lib` | + +## 🐛 Debugging Failed Tests + +### Debug Unit Tests + +```bash +# Verbose output +cargo test telegram --lib -- --nocapture + +# Specific test +cargo test telegram_split_over_limit -- --nocapture + +# Show ignored tests +cargo test telegram --lib -- --ignored +``` + +### Debug Integration Issues + +```bash +# Maximum logging +RUST_LOG=trace zeroclaw channel start + +# Check Telegram API directly +curl "https://api.telegram.org/bot/getMe" +curl "https://api.telegram.org/bot/getUpdates" + +# Validate config +cat ~/.zeroclaw/config.toml | grep -A 3 "\[channels_config.telegram\]" +``` + +### Debug Build Issues + +```bash +# Clean build +cargo clean +cargo build --release + +# Check dependencies +cargo tree | grep telegram + +# Update dependencies +cargo update +``` + +## 🎯 CI/CD Integration + +Add to your CI pipeline: + +```yaml +# .github/workflows/test.yml +name: Test Telegram Integration + +on: [push, pull_request] + +jobs: + test: + runs-on: blacksmith-2vcpu-ubuntu-2404 + steps: + - uses: actions/checkout@v3 + - uses: actions-rs/toolchain@v1 + with: + toolchain: stable + - name: Run tests + run: | + cargo test telegram --lib + cargo clippy --all-targets -- -D warnings + - name: Check formatting + run: cargo fmt --check +``` + +## 📝 Test Checklist + +Before merging code: + +- [ ] `./quick_test.sh` passes +- [ ] `./test_telegram_integration.sh` passes +- [ ] Manual tests completed +- [ ] No new clippy warnings +- [ ] Code is formatted (`cargo fmt`) +- [ ] Documentation updated +- [ ] CHANGELOG.md updated + +## 🚨 Emergency Rollback + +If tests fail in production: + +```bash +# 1. Check git history +git log --oneline src/channels/telegram.rs + +# 2. Rollback to previous version +git revert + +# 3. Rebuild +cargo build --release + +# 4. Restart service +zeroclaw service restart + +# 5. Verify +zeroclaw channel doctor +``` + +## 📚 Additional Resources + +- [Telegram Bot API Documentation](https://core.telegram.org/bots/api) +- [ZeroClaw Main README](README.md) +- [Contributing Guide](CONTRIBUTING.md) +- [Issue Tracker](https://github.com/theonlyhennygod/zeroclaw/issues) diff --git a/deny.toml b/deny.toml index 93bd114..8f29292 100644 --- a/deny.toml +++ b/deny.toml @@ -2,14 +2,23 @@ # https://embarkstudios.github.io/cargo-deny/ [advisories] -unmaintained = "workspace" -yanked = "warn" +# In v2, vulnerability advisories always emit errors (not configurable). +# unmaintained: scope of unmaintained-crate checks (all | workspace | transitive | none) +unmaintained = "all" +# yanked: deny | warn | allow +yanked = "deny" +# Ignore known unmaintained transitive deps we cannot easily replace +ignore = [ + # bincode v2.0.1 via probe-rs — project ceased but 1.3.3 considered complete + "RUSTSEC-2025-0141", +] [licenses] # All licenses are denied unless explicitly allowed allow = [ "MIT", "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", "BSD-2-Clause", "BSD-3-Clause", "ISC", @@ -19,6 +28,7 @@ allow = [ "Zlib", "MPL-2.0", "CDLA-Permissive-2.0", + "0BSD", ] unused-allowed-license = "allow" diff --git a/dev/README.md b/dev/README.md new file mode 100644 index 0000000..427b566 --- /dev/null +++ b/dev/README.md @@ -0,0 +1,169 @@ +# ZeroClaw Development Environment + +A fully containerized development sandbox for ZeroClaw agents. This environment allows you to develop, test, and debug the agent in isolation without modifying your host system. + +## Directory Structure + +- **`agent/`**: (Merged into root Dockerfile) + - The development image is built from the root `Dockerfile` using the `dev` stage (`target: dev`). + - Based on `debian:bookworm-slim` (unlike production `distroless`). + - Includes `bash`, `curl`, and debug tools. +- **`sandbox/`**: Dockerfile for the simulated user environment. + - Based on `ubuntu:22.04`. + - Pre-loaded with `git`, `python3`, `nodejs`, `npm`, `gcc`, `make`. + - Simulates a real developer machine. +- **`docker-compose.yml`**: Defines the services and `dev-net` network. +- **`cli.sh`**: Helper script to manage the lifecycle. + +## Usage + +Run all commands from the repository root using the helper script: + +### 1. Start Environment + +```bash +./dev/cli.sh up +``` + +Builds the agent from source and starts both containers. + +### 2. Enter Agent Container (`zeroclaw-dev`) + +```bash +./dev/cli.sh agent +``` + +Use this to run `zeroclaw` CLI commands manually, debug the binary, or check logs internally. + +- **Path**: `/zeroclaw-data` +- **User**: `nobody` (65534) + +### 3. Enter Sandbox (`sandbox`) + +```bash +./dev/cli.sh shell +``` + +Use this to act as the "user" or "environment" the agent interacts with. + +- **Path**: `/home/developer/workspace` +- **User**: `developer` (sudo-enabled) + +### 4. Development Cycle + +1. Make changes to Rust code in `src/`. +2. Rebuild the agent: + ```bash + ./dev/cli.sh build + ``` +3. Test changes inside the container: + ```bash + ./dev/cli.sh agent + # inside container: + zeroclaw --version + ``` + +### 5. Persistence & Shared Workspace + +The local `playground/` directory (in repo root) is mounted as the shared workspace: + +- **Agent**: `/zeroclaw-data/workspace` +- **Sandbox**: `/home/developer/workspace` + +Files created by the agent are visible to the sandbox user, and vice versa. + +The agent configuration lives in `target/.zeroclaw` (mounted to `/zeroclaw-data/.zeroclaw`), so settings persist across container rebuilds. + +### 6. Cleanup + +Stop containers and remove volumes and generated config: + +```bash +./dev/cli.sh clean +``` + +**Note:** This removes `target/.zeroclaw` (config/DB) but leaves the `playground/` directory intact. To fully wipe everything, manually delete `playground/`. + +## Local CI/CD (Docker-Only) + +Use this when you want CI-style validation without relying on GitHub Actions and without running Rust toolchain commands on your host. + +### 1. Build the local CI image + +```bash +./dev/ci.sh build-image +``` + +### 2. Run full local CI pipeline + +```bash +./dev/ci.sh all +``` + +This runs inside a container: + +- `./scripts/ci/rust_quality_gate.sh` +- `cargo test --locked --verbose` +- `cargo build --release --locked --verbose` +- `cargo deny check licenses sources` +- `cargo audit` +- Docker smoke build (`docker build --target dev ...` + `--version` check) + +To run an opt-in strict lint audit locally: + +```bash +./dev/ci.sh lint-strict +``` + +To run the incremental strict gate (changed Rust lines only): + +```bash +./dev/ci.sh lint-delta +``` + +### 3. Run targeted stages + +```bash +./dev/ci.sh lint +./dev/ci.sh lint-delta +./dev/ci.sh test +./dev/ci.sh build +./dev/ci.sh deny +./dev/ci.sh audit +./dev/ci.sh security +./dev/ci.sh docker-smoke +# Optional host-side docs gate (changed-line markdown lint) +./scripts/ci/docs_quality_gate.sh +# Optional host-side docs links gate (changed-line added links) +./scripts/ci/docs_links_gate.sh +``` + +Note: local `deny` focuses on license/source policy; advisory scanning is handled by `audit`. + +### 4. Enter CI container shell + +```bash +./dev/ci.sh shell +``` + +### 5. Optional shortcut via existing dev CLI + +```bash +./dev/cli.sh ci +./dev/cli.sh ci lint +``` + +### Isolation model + +- Rust compilation, tests, and audit/deny tools run in `zeroclaw-local-ci` container. +- Your host filesystem is mounted at `/workspace`; no host Rust toolchain is required. +- Cargo build artifacts are written to container volume `/ci-target` (not your host `target/`). +- Docker smoke stage uses your Docker daemon to build image layers, but build steps execute in containers. + +### Build cache notes + +- Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data. +- The root `Dockerfile` also caches Rust `target/` (`id=zeroclaw-target`) to speed repeat local image builds. +- Local CI reuses named Docker volumes for Cargo registry/git and target outputs. +- `./dev/ci.sh docker-smoke` and `./dev/ci.sh all` now use `docker buildx` local cache at `.cache/buildx-smoke` when available. +- The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run. diff --git a/dev/ci.sh b/dev/ci.sh new file mode 100755 index 0000000..a348a19 --- /dev/null +++ b/dev/ci.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash +set -euo pipefail + +if [ -f "dev/docker-compose.ci.yml" ]; then + COMPOSE_FILE="dev/docker-compose.ci.yml" +elif [ -f "docker-compose.ci.yml" ] && [ "$(basename "$(pwd)")" = "dev" ]; then + COMPOSE_FILE="docker-compose.ci.yml" +else + echo "❌ Run this script from repo root or dev/ directory." + exit 1 +fi + +compose_cmd=(docker compose -f "$COMPOSE_FILE") +SMOKE_CACHE_DIR="${SMOKE_CACHE_DIR:-.cache/buildx-smoke}" + +run_in_ci() { + local cmd="$1" + "${compose_cmd[@]}" run --rm local-ci bash -c "$cmd" +} + +build_smoke_image() { + if docker buildx version >/dev/null 2>&1; then + mkdir -p "$SMOKE_CACHE_DIR" + local build_args=( + --load + --target dev + --cache-to "type=local,dest=$SMOKE_CACHE_DIR,mode=max" + -t zeroclaw-local-smoke:latest + . + ) + if [ -f "$SMOKE_CACHE_DIR/index.json" ]; then + build_args=(--cache-from "type=local,src=$SMOKE_CACHE_DIR" "${build_args[@]}") + fi + docker buildx build "${build_args[@]}" + else + DOCKER_BUILDKIT=1 docker build --target dev -t zeroclaw-local-smoke:latest . + fi +} + +print_help() { + cat <<'EOF' +ZeroClaw Local CI in Docker + +Usage: ./dev/ci.sh + +Commands: + build-image Build/update the local CI image + shell Open an interactive shell inside the CI container + lint Run rustfmt + clippy correctness gate (container only) + lint-strict Run rustfmt + full clippy warnings gate (container only) + lint-delta Run strict lint delta gate on changed Rust lines (container only) + test Run cargo test (container only) + build Run release build smoke check (container only) + audit Run cargo audit (container only) + deny Run cargo deny check (container only) + security Run cargo audit + cargo deny (container only) + docker-smoke Build and verify runtime image (host docker daemon) + all Run lint, test, build, security, docker-smoke + clean Remove local CI containers and volumes +EOF +} + +if [ $# -lt 1 ]; then + print_help + exit 1 +fi + +case "$1" in + build-image) + "${compose_cmd[@]}" build local-ci + ;; + + shell) + "${compose_cmd[@]}" run --rm local-ci bash + ;; + + lint) + run_in_ci "./scripts/ci/rust_quality_gate.sh" + ;; + + lint-strict) + run_in_ci "./scripts/ci/rust_quality_gate.sh --strict" + ;; + + lint-delta) + run_in_ci "./scripts/ci/rust_strict_delta_gate.sh" + ;; + + test) + run_in_ci "cargo test --locked --verbose" + ;; + + build) + run_in_ci "cargo build --release --locked --verbose" + ;; + + audit) + run_in_ci "cargo audit" + ;; + + deny) + run_in_ci "cargo deny check licenses sources" + ;; + + security) + run_in_ci "cargo deny check licenses sources" + run_in_ci "cargo audit" + ;; + + docker-smoke) + build_smoke_image + docker run --rm zeroclaw-local-smoke:latest --version + ;; + + all) + run_in_ci "./scripts/ci/rust_quality_gate.sh" + run_in_ci "cargo test --locked --verbose" + run_in_ci "cargo build --release --locked --verbose" + run_in_ci "cargo deny check licenses sources" + run_in_ci "cargo audit" + build_smoke_image + docker run --rm zeroclaw-local-smoke:latest --version + ;; + + clean) + "${compose_cmd[@]}" down -v --remove-orphans + ;; + + *) + print_help + exit 1 + ;; +esac diff --git a/dev/ci/Dockerfile b/dev/ci/Dockerfile new file mode 100644 index 0000000..6220fe9 --- /dev/null +++ b/dev/ci/Dockerfile @@ -0,0 +1,22 @@ +# syntax=docker/dockerfile:1.7 + +FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e + +RUN apt-get update && apt-get install -y --no-install-recommends \ + ca-certificates \ + git \ + pkg-config \ + libssl-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +RUN rustup toolchain install 1.92.0 --profile minimal --component rustfmt --component clippy + +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/usr/local/cargo/git \ + cargo install --locked cargo-audit --version 0.22.1 && \ + cargo install --locked cargo-deny --version 0.18.5 + +WORKDIR /workspace + +CMD ["bash"] diff --git a/dev/cli.sh b/dev/cli.sh new file mode 100755 index 0000000..ec9aad5 --- /dev/null +++ b/dev/cli.sh @@ -0,0 +1,124 @@ +#!/bin/bash +set -e + +# Detect execution context (root or dev/) +if [ -f "dev/docker-compose.yml" ]; then + BASE_DIR="dev" + HOST_TARGET_DIR="target" +elif [ -f "docker-compose.yml" ] && [ "$(basename "$(pwd)")" == "dev" ]; then + BASE_DIR="." + HOST_TARGET_DIR="../target" +else + echo "❌ Error: Run this script from the project root or dev/ directory." + exit 1 +fi + +COMPOSE_FILE="$BASE_DIR/docker-compose.yml" + +# Colors +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +RED='\033[0;31m' +NC='\033[0m' # No Color + +function ensure_config { + CONFIG_DIR="$HOST_TARGET_DIR/.zeroclaw" + CONFIG_FILE="$CONFIG_DIR/config.toml" + WORKSPACE_DIR="$CONFIG_DIR/workspace" + + if [ ! -f "$CONFIG_FILE" ]; then + echo -e "${YELLOW}⚙️ Config file missing in target/.zeroclaw. Creating default dev config from template...${NC}" + mkdir -p "$WORKSPACE_DIR" + + # Copy template + cat "$BASE_DIR/config.template.toml" > "$CONFIG_FILE" + fi +} + +function print_help { + echo -e "${YELLOW}ZeroClaw Development Environment Manager${NC}" + echo "Usage: ./dev/cli.sh [command]" + echo "" + echo "Commands:" + echo -e " ${GREEN}up${NC} Start dev environment (Agent + Sandbox)" + echo -e " ${GREEN}down${NC} Stop containers" + echo -e " ${GREEN}shell${NC} Enter Sandbox (Ubuntu)" + echo -e " ${GREEN}agent${NC} Enter Agent (ZeroClaw CLI)" + echo -e " ${GREEN}logs${NC} View logs" + echo -e " ${GREEN}build${NC} Rebuild images" + echo -e " ${GREEN}ci${NC} Run local CI checks in Docker (see ./dev/ci.sh)" + echo -e " ${GREEN}clean${NC} Stop and wipe workspace data" +} + +if [ -z "$1" ]; then + print_help + exit 1 +fi + +case "$1" in + up) + ensure_config + echo -e "${GREEN}🚀 Starting Dev Environment...${NC}" + # Build context MUST be set correctly for docker compose + docker compose -f "$COMPOSE_FILE" up -d + echo -e "${GREEN}✅ Environment is running!${NC}" + echo -e " - Agent: http://127.0.0.1:3000" + echo -e " - Sandbox: running (background)" + echo -e " - Config: target/.zeroclaw/config.toml (Edit locally to apply changes)" + ;; + + down) + echo -e "${YELLOW}🛑 Stopping services...${NC}" + docker compose -f "$COMPOSE_FILE" down + echo -e "${GREEN}✅ Stopped.${NC}" + ;; + + shell) + echo -e "${GREEN}💻 Entering Sandbox (Ubuntu)... (Type 'exit' to leave)${NC}" + docker exec -it zeroclaw-sandbox /bin/bash + ;; + + agent) + echo -e "${GREEN}🤖 Entering Agent Container (ZeroClaw)... (Type 'exit' to leave)${NC}" + docker exec -it zeroclaw-dev /bin/bash + ;; + + logs) + docker compose -f "$COMPOSE_FILE" logs -f + ;; + + build) + echo -e "${YELLOW}🔨 Rebuilding images...${NC}" + docker compose -f "$COMPOSE_FILE" build + ensure_config + docker compose -f "$COMPOSE_FILE" up -d + echo -e "${GREEN}✅ Rebuild complete.${NC}" + ;; + + ci) + shift + if [ "$BASE_DIR" = "." ]; then + ./ci.sh "${@:-all}" + else + ./dev/ci.sh "${@:-all}" + fi + ;; + + clean) + echo -e "${RED}⚠️ WARNING: This will delete 'target/.zeroclaw' data and Docker volumes.${NC}" + read -p "Are you sure? (y/N) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + docker compose -f "$COMPOSE_FILE" down -v + rm -rf "$HOST_TARGET_DIR/.zeroclaw" + echo -e "${GREEN}🧹 Cleaned up (playground/ remains intact).${NC}" + else + echo "Cancelled." + fi + ;; + + *) + print_help + exit 1 + ;; +esac diff --git a/dev/config.template.toml b/dev/config.template.toml new file mode 100644 index 0000000..f768587 --- /dev/null +++ b/dev/config.template.toml @@ -0,0 +1,12 @@ +workspace_dir = "/zeroclaw-data/workspace" +config_path = "/zeroclaw-data/.zeroclaw/config.toml" +# This is the Ollama Base URL, not a secret key +api_key = "http://host.docker.internal:11434" +default_provider = "ollama" +default_model = "llama3.2" +default_temperature = 0.7 + +[gateway] +port = 3000 +host = "[::]" +allow_public_bind = true diff --git a/dev/docker-compose.ci.yml b/dev/docker-compose.ci.yml new file mode 100644 index 0000000..2078726 --- /dev/null +++ b/dev/docker-compose.ci.yml @@ -0,0 +1,23 @@ +name: zeroclaw-local-ci + +services: + local-ci: + build: + context: .. + dockerfile: dev/ci/Dockerfile + container_name: zeroclaw-local-ci + working_dir: /workspace + environment: + - CARGO_TERM_COLOR=always + - PATH=/usr/local/cargo/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin + - CARGO_TARGET_DIR=/ci-target + volumes: + - ..:/workspace + - cargo-registry:/usr/local/cargo/registry + - cargo-git:/usr/local/cargo/git + - ci-target:/ci-target + +volumes: + cargo-registry: + cargo-git: + ci-target: diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml new file mode 100644 index 0000000..93de91a --- /dev/null +++ b/dev/docker-compose.yml @@ -0,0 +1,59 @@ +# Development Environment for ZeroClaw Agentic Testing +# +# Use this for: +# - Running the agent in a sandboxed environment +# - Testing dangerous commands safely +# - Developing new skills/integrations +# +# Usage: +# cd dev && ./cli.sh up +# or from root: ./dev/cli.sh up +name: zeroclaw-dev +services: + # ── The Agent (Development Image) ── + # Builds from source using the 'dev' stage of the root Dockerfile + zeroclaw-dev: + build: + context: .. + dockerfile: Dockerfile + target: dev + container_name: zeroclaw-dev + restart: unless-stopped + environment: + - API_KEY + - PROVIDER + - ZEROCLAW_MODEL + - ZEROCLAW_GATEWAY_PORT=3000 + - SANDBOX_HOST=zeroclaw-sandbox + volumes: + # Mount single config file (avoids shadowing other files in .zeroclaw) + - ../target/.zeroclaw/config.toml:/zeroclaw-data/.zeroclaw/config.toml + # Mount shared workspace + - ../playground:/zeroclaw-data/workspace + ports: + - "127.0.0.1:3000:3000" + networks: + - dev-net + + # ── The Sandbox (Ubuntu Environment) ── + # A fully loaded Ubuntu environment for the agent to play in. + sandbox: + build: + context: sandbox # Context relative to dev/ + dockerfile: Dockerfile + container_name: zeroclaw-sandbox + hostname: dev-box + command: ["tail", "-f", "/dev/null"] + working_dir: /home/developer/workspace + user: developer + environment: + - TERM=xterm-256color + - SHELL=/bin/bash + volumes: + - ../playground:/home/developer/workspace # Mount local playground + networks: + - dev-net + +networks: + dev-net: + driver: bridge diff --git a/dev/sandbox/Dockerfile b/dev/sandbox/Dockerfile new file mode 100644 index 0000000..6b81a7a --- /dev/null +++ b/dev/sandbox/Dockerfile @@ -0,0 +1,34 @@ +FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1 + +# Prevent interactive prompts during package installation +ENV DEBIAN_FRONTEND=noninteractive + +# Install common development tools and runtimes +# - Node.js: Install v20 (LTS) from NodeSource +# - Core: curl, git, vim, build-essential (gcc, make) +# - Python: python3, pip +# - Network: ping, dnsutils +RUN apt-get update && apt-get install -y curl && \ + curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \ + apt-get install -y \ + nodejs \ + wget git vim nano unzip zip \ + build-essential \ + python3 python3-pip \ + sudo \ + iputils-ping dnsutils net-tools \ + && rm -rf /var/lib/apt/lists/* \ + && node --version && npm --version + +# Create a non-root user 'developer' with UID 1000 +# Grant passwordless sudo to simulate a local dev environment (using safe sudoers.d) +RUN useradd -m -s /bin/bash -u 1000 developer && \ + echo "developer ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/developer && \ + chmod 0440 /etc/sudoers.d/developer + +# Set up the workspace +USER developer +WORKDIR /home/developer/workspace + +# Default command +CMD ["/bin/bash"] diff --git a/docker-compose.yml b/docker-compose.yml index a923676..3e85171 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,17 +25,30 @@ services: # Options: openrouter, openai, anthropic, ollama - PROVIDER=${PROVIDER:-openrouter} + # Allow public bind inside Docker (required for container networking) + - ZEROCLAW_ALLOW_PUBLIC_BIND=true + # Optional: Model override # - ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514 volumes: - # Persist workspace and config - - zeroclaw-data:/data + # Persist workspace and config (must match WORKDIR/HOME in Dockerfile) + - zeroclaw-data:/zeroclaw-data ports: - # Gateway API port - - "3000:3000" + # Gateway API port (override HOST_PORT if 3000 is taken) + - "${HOST_PORT:-3000}:3000" + # Resource limits + deploy: + resources: + limits: + cpus: '2' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M + # Health check healthcheck: test: ["CMD", "zeroclaw", "doctor"] diff --git a/docs/Hardware_architecture.jpg b/docs/Hardware_architecture.jpg new file mode 100644 index 0000000..8daf589 Binary files /dev/null and b/docs/Hardware_architecture.jpg differ diff --git a/docs/actions-source-policy.md b/docs/actions-source-policy.md new file mode 100644 index 0000000..21eb6e2 --- /dev/null +++ b/docs/actions-source-policy.md @@ -0,0 +1,89 @@ +# Actions Source Policy (Phase 1) + +This document defines the current GitHub Actions source-control policy for this repository. + +Phase 1 objective: lock down action sources with minimal disruption, before full SHA pinning. + +## Current Policy + +- Repository Actions permissions: enabled +- Allowed actions mode: selected +- SHA pinning required: false (deferred to Phase 2) + +Selected allowlist patterns: + +- `actions/*` (covers `actions/cache`, `actions/checkout`, `actions/upload-artifact`, `actions/download-artifact`, and other first-party actions) +- `docker/*` +- `dtolnay/rust-toolchain@*` +- `Swatinem/rust-cache@*` +- `DavidAnson/markdownlint-cli2-action@*` +- `lycheeverse/lychee-action@*` +- `EmbarkStudios/cargo-deny-action@*` +- `rhysd/actionlint@*` +- `softprops/action-gh-release@*` +- `sigstore/cosign-installer@*` +- `useblacksmith/*` (Blacksmith self-hosted runner infrastructure) + +## Change Control Export + +Use these commands to export the current effective policy for audit/change control: + +```bash +gh api repos/zeroclaw-labs/zeroclaw/actions/permissions +gh api repos/zeroclaw-labs/zeroclaw/actions/permissions/selected-actions +``` + +Record each policy change with: + +- change date/time (UTC) +- actor +- reason +- allowlist delta (added/removed patterns) +- rollback note + +## Why This Phase + +- Reduces supply-chain risk from unreviewed marketplace actions. +- Preserves current CI/CD functionality with low migration overhead. +- Prepares for Phase 2 full SHA pinning without blocking active development. + +## Agentic Workflow Guardrails + +Because this repository has high agent-authored change volume: + +- Any PR that adds or changes `uses:` action sources must include an allowlist impact note. +- New third-party actions require explicit maintainer review before allowlisting. +- Expand allowlist only for verified missing actions; avoid broad wildcard exceptions. +- Keep rollback instructions in the PR description for Actions policy changes. + +## Validation Checklist + +After allowlist changes, validate: + +1. `CI` +2. `Docker` +3. `Security Audit` +4. `Workflow Sanity` +5. `Release` (when safe to run) + +Failure mode to watch for: + +- `action is not allowed by policy` + +If encountered, add only the specific trusted missing action, rerun, and document why. + +Latest sweep notes: + +- 2026-02-16: Hidden dependency discovered in `release.yml`: `sigstore/cosign-installer@...` + - Added allowlist pattern: `sigstore/cosign-installer@*` +- 2026-02-16: Blacksmith migration blocked workflow execution + - Added allowlist pattern: `useblacksmith/*` for self-hosted runner infrastructure + - Actions: `useblacksmith/setup-docker-builder@v1`, `useblacksmith/build-push-action@v2` + +## Rollback + +Emergency unblock path: + +1. Temporarily set Actions policy back to `all`. +2. Restore selected allowlist after identifying missing entries. +3. Record incident and final allowlist delta. diff --git a/docs/adding-boards-and-tools.md b/docs/adding-boards-and-tools.md new file mode 100644 index 0000000..a7e9eaa --- /dev/null +++ b/docs/adding-boards-and-tools.md @@ -0,0 +1,116 @@ +# Adding Boards and Tools — ZeroClaw Hardware Guide + +This guide explains how to add new hardware boards and custom tools to ZeroClaw. + +## Quick Start: Add a Board via CLI + +```bash +# Add a board (updates ~/.zeroclaw/config.toml) +zeroclaw peripheral add nucleo-f401re /dev/ttyACM0 +zeroclaw peripheral add arduino-uno /dev/cu.usbmodem12345 +zeroclaw peripheral add rpi-gpio native # for Raspberry Pi GPIO (Linux) + +# Restart daemon to apply +zeroclaw daemon --host 127.0.0.1 --port 8080 +``` + +## Supported Boards + +| Board | Transport | Path Example | +|-----------------|-----------|---------------------------| +| nucleo-f401re | serial | /dev/ttyACM0, /dev/cu.usbmodem* | +| arduino-uno | serial | /dev/ttyACM0, /dev/cu.usbmodem* | +| arduino-uno-q | bridge | (Uno Q IP) | +| rpi-gpio | native | native | +| esp32 | serial | /dev/ttyUSB0 | + +## Manual Config + +Edit `~/.zeroclaw/config.toml`: + +```toml +[peripherals] +enabled = true +datasheet_dir = "docs/datasheets" # optional: RAG for "turn on red led" → pin 13 + +[[peripherals.boards]] +board = "nucleo-f401re" +transport = "serial" +path = "/dev/ttyACM0" +baud = 115200 + +[[peripherals.boards]] +board = "arduino-uno" +transport = "serial" +path = "/dev/cu.usbmodem12345" +baud = 115200 +``` + +## Adding a Datasheet (RAG) + +Place `.md` or `.txt` files in `docs/datasheets/` (or your `datasheet_dir`). Name files by board: `nucleo-f401re.md`, `arduino-uno.md`. + +### Pin Aliases (Recommended) + +Add a `## Pin Aliases` section so the agent can map "red led" → pin 13: + +```markdown +# My Board + +## Pin Aliases + +| alias | pin | +|-------------|-----| +| red_led | 13 | +| builtin_led | 13 | +| user_led | 5 | +``` + +Or use key-value format: + +```markdown +## Pin Aliases +red_led: 13 +builtin_led: 13 +``` + +### PDF Datasheets + +With the `rag-pdf` feature, ZeroClaw can index PDF files: + +```bash +cargo build --features hardware,rag-pdf +``` + +Place PDFs in the datasheet directory. They are extracted and chunked for RAG. + +## Adding a New Board Type + +1. **Create a datasheet** — `docs/datasheets/my-board.md` with pin aliases and GPIO info. +2. **Add to config** — `zeroclaw peripheral add my-board /dev/ttyUSB0` +3. **Implement a peripheral** (optional) — For custom protocols, implement the `Peripheral` trait in `src/peripherals/` and register in `create_peripheral_tools`. + +See `docs/hardware-peripherals-design.md` for the full design. + +## Adding a Custom Tool + +1. Implement the `Tool` trait in `src/tools/`. +2. Register in `create_peripheral_tools` (for hardware tools) or the agent tool registry. +3. Add a tool description to the agent's `tool_descs` in `src/agent/loop_.rs`. + +## CLI Reference + +| Command | Description | +|---------|-------------| +| `zeroclaw peripheral list` | List configured boards | +| `zeroclaw peripheral add ` | Add board (writes config) | +| `zeroclaw peripheral flash` | Flash Arduino firmware | +| `zeroclaw peripheral flash-nucleo` | Flash Nucleo firmware | +| `zeroclaw hardware discover` | List USB devices | +| `zeroclaw hardware info` | Chip info via probe-rs | + +## Troubleshooting + +- **Serial port not found** — On macOS use `/dev/cu.usbmodem*`; on Linux use `/dev/ttyACM0` or `/dev/ttyUSB0`. +- **Build with hardware** — `cargo build --features hardware` +- **Probe-rs for Nucleo** — `cargo build --features hardware,probe` diff --git a/docs/agnostic-security.md b/docs/agnostic-security.md new file mode 100644 index 0000000..7ed0273 --- /dev/null +++ b/docs/agnostic-security.md @@ -0,0 +1,348 @@ +# Agnostic Security: Zero Impact on Portability + +## Core Question: Will security features break... +1. ❓ Fast cross-compilation builds? +2. ❓ Pluggable architecture (swap anything)? +3. ❓ Hardware agnosticism (ARM, x86, RISC-V)? +4. ❓ Small hardware support (<5MB RAM, $10 boards)? + +**Answer: NO to all** — Security is designed as **optional feature flags** with **platform-specific conditional compilation**. + +--- + +## 1. Build Speed: Feature-Gated Security + +### Cargo.toml: Security Features Behind Features + +```toml +[features] +default = ["basic-security"] + +# Basic security (always on, zero overhead) +basic-security = [] + +# Platform-specific sandboxing (opt-in per platform) +sandbox-landlock = [] # Linux only +sandbox-firejail = [] # Linux only +sandbox-bubblewrap = []# macOS/Linux +sandbox-docker = [] # All platforms (heavy) + +# Full security suite (for production builds) +security-full = [ + "basic-security", + "sandbox-landlock", + "resource-monitoring", + "audit-logging", +] + +# Resource & audit monitoring +resource-monitoring = [] +audit-logging = [] + +# Development builds (fastest, no extra deps) +dev = [] +``` + +### Build Commands (Choose Your Profile) + +```bash +# Ultra-fast dev build (no security extras) +cargo build --profile dev + +# Release build with basic security (default) +cargo build --release +# → Includes: allowlist, path blocking, injection protection +# → Excludes: Landlock, Firejail, audit logging + +# Production build with full security +cargo build --release --features security-full +# → Includes: Everything + +# Platform-specific sandbox only +cargo build --release --features sandbox-landlock # Linux +cargo build --release --features sandbox-docker # All platforms +``` + +### Conditional Compilation: Zero Overhead When Disabled + +```rust +// src/security/mod.rs + +#[cfg(feature = "sandbox-landlock")] +mod landlock; +#[cfg(feature = "sandbox-landlock")] +pub use landlock::LandlockSandbox; + +#[cfg(feature = "sandbox-firejail")] +mod firejail; +#[cfg(feature = "sandbox-firejail")] +pub use firejail::FirejailSandbox; + +// Always-include basic security (no feature flag) +pub mod policy; // allowlist, path blocking, injection protection +``` + +**Result**: When features are disabled, the code isn't even compiled — **zero binary bloat**. + +--- + +## 2. Pluggable Architecture: Security Is a Trait Too + +### Security Backend Trait (Swappable Like Everything Else) + +```rust +// src/security/traits.rs + +#[async_trait] +pub trait Sandbox: Send + Sync { + /// Wrap a command with sandbox protection + fn wrap_command(&self, cmd: &mut std::process::Command) -> std::io::Result<()>; + + /// Check if sandbox is available on this platform + fn is_available(&self) -> bool; + + /// Human-readable name + fn name(&self) -> &str; +} + +// No-op sandbox (always available) +pub struct NoopSandbox; + +impl Sandbox for NoopSandbox { + fn wrap_command(&self, _cmd: &mut std::process::Command) -> std::io::Result<()> { + Ok(()) // Pass through unchanged + } + + fn is_available(&self) -> bool { true } + fn name(&self) -> &str { "none" } +} +``` + +### Factory Pattern: Auto-Select Based on Features + +```rust +// src/security/factory.rs + +pub fn create_sandbox() -> Box { + #[cfg(feature = "sandbox-landlock")] + { + if LandlockSandbox::is_available() { + return Box::new(LandlockSandbox::new()); + } + } + + #[cfg(feature = "sandbox-firejail")] + { + if FirejailSandbox::is_available() { + return Box::new(FirejailSandbox::new()); + } + } + + #[cfg(feature = "sandbox-bubblewrap")] + { + if BubblewrapSandbox::is_available() { + return Box::new(BubblewrapSandbox::new()); + } + } + + #[cfg(feature = "sandbox-docker")] + { + if DockerSandbox::is_available() { + return Box::new(DockerSandbox::new()); + } + } + + // Fallback: always available + Box::new(NoopSandbox) +} +``` + +**Just like providers, channels, and memory — security is pluggable!** + +--- + +## 3. Hardware Agnosticism: Same Binary, Different Platforms + +### Cross-Platform Behavior Matrix + +| Platform | Builds On | Runtime Behavior | +|----------|-----------|------------------| +| **Linux ARM** (Raspberry Pi) | ✅ Yes | Landlock → None (graceful) | +| **Linux x86_64** | ✅ Yes | Landlock → Firejail → None | +| **macOS ARM** (M1/M2) | ✅ Yes | Bubblewrap → None | +| **macOS x86_64** | ✅ Yes | Bubblewrap → None | +| **Windows ARM** | ✅ Yes | None (app-layer) | +| **Windows x86_64** | ✅ Yes | None (app-layer) | +| **RISC-V Linux** | ✅ Yes | Landlock → None | + +### How It Works: Runtime Detection + +```rust +// src/security/detect.rs + +impl SandboxingStrategy { + /// Choose best available sandbox AT RUNTIME + pub fn detect() -> SandboxingStrategy { + #[cfg(target_os = "linux")] + { + // Try Landlock first (kernel feature detection) + if Self::probe_landlock() { + return SandboxingStrategy::Landlock; + } + + // Try Firejail (user-space tool detection) + if Self::probe_firejail() { + return SandboxingStrategy::Firejail; + } + } + + #[cfg(target_os = "macos")] + { + if Self::probe_bubblewrap() { + return SandboxingStrategy::Bubblewrap; + } + } + + // Always available fallback + SandboxingStrategy::ApplicationLayer + } +} +``` + +**Same binary runs everywhere** — it just adapts its protection level based on what's available. + +--- + +## 4. Small Hardware: Memory Impact Analysis + +### Binary Size Impact (Estimated) + +| Feature | Code Size | RAM Overhead | Status | +|---------|-----------|--------------|--------| +| **Base ZeroClaw** | 3.4MB | <5MB | ✅ Current | +| **+ Landlock** | +50KB | +100KB | ✅ Linux 5.13+ | +| **+ Firejail wrapper** | +20KB | +0KB (external) | ✅ Linux + firejail | +| **+ Memory monitoring** | +30KB | +50KB | ✅ All platforms | +| **+ Audit logging** | +40KB | +200KB (buffered) | ✅ All platforms | +| **Full security** | +140KB | +350KB | ✅ Still <6MB total | + +### $10 Hardware Compatibility + +| Hardware | RAM | ZeroClaw (base) | ZeroClaw (full security) | Status | +|----------|-----|-----------------|--------------------------|--------| +| **Raspberry Pi Zero** | 512MB | ✅ 2% | ✅ 2.5% | Works | +| **Orange Pi Zero** | 512MB | ✅ 2% | ✅ 2.5% | Works | +| **NanoPi NEO** | 256MB | ✅ 4% | ✅ 5% | Works | +| **C.H.I.P.** | 512MB | ✅ 2% | ✅ 2.5% | Works | +| **Rock64** | 1GB | ✅ 1% | ✅ 1.2% | Works | + +**Even with full security, ZeroClaw uses <5% of RAM on $10 boards.** + +--- + +## 5. Agnostic Swaps: Everything Remains Pluggable + +### ZeroClaw's Core Promise: Swap Anything + +```rust +// Providers (already pluggable) +Box + +// Channels (already pluggable) +Box + +// Memory (already pluggable) +Box + +// Tunnels (already pluggable) +Box + +// NOW ALSO: Security (newly pluggable) +Box +Box +Box +``` + +### Swap Security Backends via Config + +```toml +# Use no sandbox (fastest, app-layer only) +[security.sandbox] +backend = "none" + +# Use Landlock (Linux kernel LSM, native) +[security.sandbox] +backend = "landlock" + +# Use Firejail (user-space, needs firejail installed) +[security.sandbox] +backend = "firejail" + +# Use Docker (heaviest, most isolated) +[security.sandbox] +backend = "docker" +``` + +**Just like swapping OpenAI for Gemini, or SQLite for PostgreSQL.** + +--- + +## 6. Dependency Impact: Minimal New Deps + +### Current Dependencies (for context) +``` +reqwest, tokio, serde, anyhow, uuid, chrono, rusqlite, +axum, tracing, opentelemetry, ... +``` + +### Security Feature Dependencies + +| Feature | New Dependencies | Platform | +|---------|------------------|----------| +| **Landlock** | `landlock` crate (pure Rust) | Linux only | +| **Firejail** | None (external binary) | Linux only | +| **Bubblewrap** | None (external binary) | macOS/Linux | +| **Docker** | `bollard` crate (Docker API) | All platforms | +| **Memory monitoring** | None (std::alloc) | All platforms | +| **Audit logging** | None (already have hmac/sha2) | All platforms | + +**Result**: Most features add **zero new Rust dependencies** — they either: +1. Use pure-Rust crates (landlock) +2. Wrap external binaries (Firejail, Bubblewrap) +3. Use existing deps (hmac, sha2 already in Cargo.toml) + +--- + +## Summary: Core Value Propositions Preserved + +| Value Prop | Before | After (with security) | Status | +|------------|--------|----------------------|--------| +| **<5MB RAM** | ✅ <5MB | ✅ <6MB (worst case) | ✅ Preserved | +| **<10ms startup** | ✅ <10ms | ✅ <15ms (detection) | ✅ Preserved | +| **3.4MB binary** | ✅ 3.4MB | ✅ 3.5MB (with all features) | ✅ Preserved | +| **ARM + x86 + RISC-V** | ✅ All | ✅ All | ✅ Preserved | +| **$10 hardware** | ✅ Works | ✅ Works | ✅ Preserved | +| **Pluggable everything** | ✅ Yes | ✅ Yes (security too) | ✅ Enhanced | +| **Cross-platform** | ✅ Yes | ✅ Yes | ✅ Preserved | + +--- + +## The Key: Feature Flags + Conditional Compilation + +```bash +# Developer build (fastest, no extra features) +cargo build --profile dev + +# Standard release (your current build) +cargo build --release + +# Production with full security +cargo build --release --features security-full + +# Target specific hardware +cargo build --release --target aarch64-unknown-linux-gnu # Raspberry Pi +cargo build --release --target riscv64gc-unknown-linux-gnu # RISC-V +cargo build --release --target armv7-unknown-linux-gnueabihf # ARMv7 +``` + +**Every target, every platform, every use case — still fast, still small, still agnostic.** diff --git a/docs/arduino-uno-q-setup.md b/docs/arduino-uno-q-setup.md new file mode 100644 index 0000000..8e170e8 --- /dev/null +++ b/docs/arduino-uno-q-setup.md @@ -0,0 +1,217 @@ +# ZeroClaw on Arduino Uno Q — Step-by-Step Guide + +Run ZeroClaw on the Arduino Uno Q's Linux side. Telegram works over WiFi; GPIO control uses the Bridge (requires a minimal App Lab app). + +--- + +## What's Included (No Code Changes Needed) + +ZeroClaw includes everything needed for Arduino Uno Q. **Clone the repo and follow this guide — no patches or custom code required.** + +| Component | Location | Purpose | +|-----------|----------|---------| +| Bridge app | `firmware/zeroclaw-uno-q-bridge/` | MCU sketch + Python socket server (port 9999) for GPIO | +| Bridge tools | `src/peripherals/uno_q_bridge.rs` | `gpio_read` / `gpio_write` tools that talk to the Bridge over TCP | +| Setup command | `src/peripherals/uno_q_setup.rs` | `zeroclaw peripheral setup-uno-q` deploys the Bridge via scp + arduino-app-cli | +| Config schema | `board = "arduino-uno-q"`, `transport = "bridge"` | Supported in `config.toml` | + +Build with `--features hardware` (or the default features) to include Uno Q support. + +--- + +## Prerequisites + +- Arduino Uno Q with WiFi configured +- Arduino App Lab installed on your Mac (for initial setup and deployment) +- API key for LLM (OpenRouter, etc.) + +--- + +## Phase 1: Initial Uno Q Setup (One-Time) + +### 1.1 Configure Uno Q via App Lab + +1. Download [Arduino App Lab](https://docs.arduino.cc/software/app-lab/) (AppImage on Linux). +2. Connect Uno Q via USB, power it on. +3. Open App Lab, connect to the board. +4. Follow the setup wizard: + - Set username and password (for SSH) + - Configure WiFi (SSID, password) + - Apply any firmware updates +5. Note the IP address shown (e.g. `arduino@192.168.1.42`) or find it later via `ip addr show` in App Lab's terminal. + +### 1.2 Verify SSH Access + +```bash +ssh arduino@ +# Enter the password you set +``` + +--- + +## Phase 2: Install ZeroClaw on Uno Q + +### Option A: Build on the Device (Simpler, ~20–40 min) + +```bash +# SSH into Uno Q +ssh arduino@ + +# Install Rust +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +source ~/.cargo/env + +# Install build deps (Debian) +sudo apt-get update +sudo apt-get install -y pkg-config libssl-dev + +# Clone zeroclaw (or scp your project) +git clone https://github.com/theonlyhennygod/zeroclaw.git +cd zeroclaw + +# Build (takes ~15–30 min on Uno Q) +cargo build --release + +# Install +sudo cp target/release/zeroclaw /usr/local/bin/ +``` + +### Option B: Cross-Compile on Mac (Faster) + +```bash +# On your Mac — add aarch64 target +rustup target add aarch64-unknown-linux-gnu + +# Install cross-compiler (macOS; required for linking) +brew tap messense/macos-cross-toolchains +brew install aarch64-unknown-linux-gnu + +# Build +CC_aarch64_unknown_linux_gnu=aarch64-unknown-linux-gnu-gcc cargo build --release --target aarch64-unknown-linux-gnu + +# Copy to Uno Q +scp target/aarch64-unknown-linux-gnu/release/zeroclaw arduino@:~/ +ssh arduino@ "sudo mv ~/zeroclaw /usr/local/bin/" +``` + +If cross-compile fails, use Option A and build on the device. + +--- + +## Phase 3: Configure ZeroClaw + +### 3.1 Run Onboard (or Create Config Manually) + +```bash +ssh arduino@ + +# Quick config +zeroclaw onboard --api-key YOUR_OPENROUTER_KEY --provider openrouter + +# Or create config manually +mkdir -p ~/.zeroclaw/workspace +nano ~/.zeroclaw/config.toml +``` + +### 3.2 Minimal config.toml + +```toml +api_key = "YOUR_OPENROUTER_API_KEY" +default_provider = "openrouter" +default_model = "anthropic/claude-sonnet-4" + +[peripherals] +enabled = false +# GPIO via Bridge requires Phase 4 + +[channels_config.telegram] +bot_token = "YOUR_TELEGRAM_BOT_TOKEN" +allowed_users = ["*"] + +[gateway] +host = "127.0.0.1" +port = 8080 +allow_public_bind = false + +[agent] +compact_context = true +``` + +--- + +## Phase 4: Run ZeroClaw Daemon + +```bash +ssh arduino@ + +# Run daemon (Telegram polling works over WiFi) +zeroclaw daemon --host 127.0.0.1 --port 8080 +``` + +**At this point:** Telegram chat works. Send messages to your bot — ZeroClaw responds. No GPIO yet. + +--- + +## Phase 5: GPIO via Bridge (ZeroClaw Handles It) + +ZeroClaw includes the Bridge app and setup command. + +### 5.1 Deploy Bridge App + +**From your Mac** (with zeroclaw repo): +```bash +zeroclaw peripheral setup-uno-q --host 192.168.0.48 +``` + +**From the Uno Q** (SSH'd in): +```bash +zeroclaw peripheral setup-uno-q +``` + +This copies the Bridge app to `~/ArduinoApps/zeroclaw-uno-q-bridge` and starts it. + +### 5.2 Add to config.toml + +```toml +[peripherals] +enabled = true + +[[peripherals.boards]] +board = "arduino-uno-q" +transport = "bridge" +``` + +### 5.3 Run ZeroClaw + +```bash +zeroclaw daemon --host 127.0.0.1 --port 8080 +``` + +Now when you message your Telegram bot *"Turn on the LED"* or *"Set pin 13 high"*, ZeroClaw uses `gpio_write` via the Bridge. + +--- + +## Summary: Commands Start to End + +| Step | Command | +|------|---------| +| 1 | Configure Uno Q in App Lab (WiFi, SSH) | +| 2 | `ssh arduino@` | +| 3 | `curl -sSf https://sh.rustup.rs \| sh -s -- -y && source ~/.cargo/env` | +| 4 | `sudo apt-get install -y pkg-config libssl-dev` | +| 5 | `git clone https://github.com/theonlyhennygod/zeroclaw.git && cd zeroclaw` | +| 6 | `cargo build --release --no-default-features` | +| 7 | `zeroclaw onboard --api-key KEY --provider openrouter` | +| 8 | Edit `~/.zeroclaw/config.toml` (add Telegram bot_token) | +| 9 | `zeroclaw daemon --host 127.0.0.1 --port 8080` | +| 10 | Message your Telegram bot — it responds | + +--- + +## Troubleshooting + +- **"command not found: zeroclaw"** — Use full path: `/usr/local/bin/zeroclaw` or ensure `~/.cargo/bin` is in PATH. +- **Telegram not responding** — Check bot_token, allowed_users, and that the Uno Q has internet (WiFi). +- **Out of memory** — Use `--no-default-features` to reduce binary size; consider `compact_context = true`. +- **GPIO commands ignored** — Ensure Bridge app is running (`zeroclaw peripheral setup-uno-q` deploys and starts it). Config must have `board = "arduino-uno-q"` and `transport = "bridge"`. +- **LLM provider (GLM/Zhipu)** — Use `default_provider = "glm"` or `"zhipu"` with `GLM_API_KEY` in env or config. ZeroClaw uses the correct v4 endpoint. diff --git a/docs/audit-logging.md b/docs/audit-logging.md new file mode 100644 index 0000000..8871adb --- /dev/null +++ b/docs/audit-logging.md @@ -0,0 +1,186 @@ +# Audit Logging for ZeroClaw + +## Problem +ZeroClaw logs actions but lacks tamper-evident audit trails for: +- Who executed what command +- When and from which channel +- What resources were accessed +- Whether security policies were triggered + +--- + +## Proposed Audit Log Format + +```json +{ + "timestamp": "2026-02-16T12:34:56Z", + "event_id": "evt_1a2b3c4d", + "event_type": "command_execution", + "actor": { + "channel": "telegram", + "user_id": "123456789", + "username": "@alice" + }, + "action": { + "command": "ls -la", + "risk_level": "low", + "approved": false, + "allowed": true + }, + "result": { + "success": true, + "exit_code": 0, + "duration_ms": 15 + }, + "security": { + "policy_violation": false, + "rate_limit_remaining": 19 + }, + "signature": "SHA256:abc123..." // HMAC for tamper evidence +} +``` + +--- + +## Implementation + +```rust +// src/security/audit.rs +use serde::{Deserialize, Serialize}; +use std::io::Write; +use std::path::PathBuf; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditEvent { + pub timestamp: String, + pub event_id: String, + pub event_type: AuditEventType, + pub actor: Actor, + pub action: Action, + pub result: ExecutionResult, + pub security: SecurityContext, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AuditEventType { + CommandExecution, + FileAccess, + ConfigurationChange, + AuthSuccess, + AuthFailure, + PolicyViolation, +} + +pub struct AuditLogger { + log_path: PathBuf, + signing_key: Option>, +} + +impl AuditLogger { + pub fn log(&self, event: &AuditEvent) -> anyhow::Result<()> { + let mut line = serde_json::to_string(event)?; + + // Add HMAC signature if key configured + if let Some(ref key) = self.signing_key { + let signature = compute_hmac(key, line.as_bytes()); + line.push_str(&format!("\n\"signature\": \"{}\"", signature)); + } + + let mut file = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&self.log_path)?; + + writeln!(file, "{}", line)?; + file.sync_all()?; // Force flush for durability + Ok(()) + } + + pub fn search(&self, filter: AuditFilter) -> Vec { + // Search log file by filter criteria + todo!() + } +} +``` + +--- + +## Config Schema + +```toml +[security.audit] +enabled = true +log_path = "~/.config/zeroclaw/audit.log" +max_size_mb = 100 +rotate = "daily" # daily | weekly | size + +# Tamper evidence +sign_events = true +signing_key_path = "~/.config/zeroclaw/audit.key" + +# What to log +log_commands = true +log_file_access = true +log_auth_events = true +log_policy_violations = true +``` + +--- + +## Audit Query CLI + +```bash +# Show all commands executed by @alice +zeroclaw audit --user @alice + +# Show all high-risk commands +zeroclaw audit --risk high + +# Show violations from last 24 hours +zeroclaw audit --since 24h --violations-only + +# Export to JSON for analysis +zeroclaw audit --format json --output audit.json + +# Verify log integrity +zeroclaw audit --verify-signatures +``` + +--- + +## Log Rotation + +```rust +pub fn rotate_audit_log(log_path: &PathBuf, max_size: u64) -> anyhow::Result<()> { + let metadata = std::fs::metadata(log_path)?; + if metadata.len() < max_size { + return Ok(()); + } + + // Rotate: audit.log -> audit.log.1 -> audit.log.2 -> ... + let stem = log_path.file_stem().unwrap_or_default(); + let extension = log_path.extension().and_then(|s| s.to_str()).unwrap_or("log"); + + for i in (1..10).rev() { + let old_name = format!("{}.{}.{}", stem, i, extension); + let new_name = format!("{}.{}.{}", stem, i + 1, extension); + let _ = std::fs::rename(old_name, new_name); + } + + let rotated = format!("{}.1.{}", stem, extension); + std::fs::rename(log_path, &rotated)?; + + Ok(()) +} +``` + +--- + +## Implementation Priority + +| Phase | Feature | Effort | Security Value | +|-------|---------|--------|----------------| +| **P0** | Basic event logging | Low | Medium | +| **P1** | Query CLI | Medium | Medium | +| **P2** | HMAC signing | Medium | High | +| **P3** | Log rotation + archival | Low | Medium | diff --git a/docs/ci-map.md b/docs/ci-map.md new file mode 100644 index 0000000..344ed6f --- /dev/null +++ b/docs/ci-map.md @@ -0,0 +1,110 @@ +# CI Workflow Map + +This document explains what each GitHub workflow does, when it runs, and whether it should block merges. + +## Merge-Blocking vs Optional + +Merge-blocking checks should stay small and deterministic. Optional checks are useful for automation and maintenance, but should not block normal development. + +### Merge-Blocking + +- `.github/workflows/ci.yml` (`CI`) + - Purpose: Rust validation (`cargo fmt --all -- --check`, `cargo clippy --locked --all-targets -- -D clippy::correctness`, strict delta lint gate on changed Rust lines, `test`, release build smoke) + docs quality checks when docs change (`markdownlint` blocks only issues on changed lines; link check scans only links added on changed lines) + - Additional behavior: PRs that change `.github/workflows/**` require at least one approving review from a login in `WORKFLOW_OWNER_LOGINS` (repository variable fallback: `theonlyhennygod,willsarg`) + - Additional behavior: lint gates run before `test`/`build`; when lint/docs gates fail on PRs, CI posts an actionable feedback comment with failing gate names and local fix commands + - Merge gate: `CI Required Gate` +- `.github/workflows/workflow-sanity.yml` (`Workflow Sanity`) + - Purpose: lint GitHub workflow files (`actionlint`, tab checks) + - Recommended for workflow-changing PRs +- `.github/workflows/pr-intake-sanity.yml` (`PR Intake Sanity`) + - Purpose: safe pre-CI PR checks (template completeness, added-line tabs/trailing-whitespace/conflict markers) with immediate sticky feedback comment + +### Non-Blocking but Important + +- `.github/workflows/docker.yml` (`Docker`) + - Purpose: PR docker smoke check and publish images on `main`/tag pushes +- `.github/workflows/security.yml` (`Security Audit`) + - Purpose: dependency advisories (`cargo audit`) and policy/license checks (`cargo deny`) +- `.github/workflows/release.yml` (`Release`) + - Purpose: build tagged release artifacts and publish GitHub releases +- `.github/workflows/label-policy-sanity.yml` (`Label Policy Sanity`) + - Purpose: validate shared contributor-tier policy in `.github/label-policy.json` and ensure label workflows consume that policy +- `.github/workflows/rust-reusable.yml` (`Rust Reusable Job`) + - Purpose: reusable Rust setup/cache + command runner for workflow-call consumers + +### Optional Repository Automation + +- `.github/workflows/labeler.yml` (`PR Labeler`) + - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`: `) + - Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule + - Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`) + - Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`) + - Additional behavior: module namespaces are compacted — one specific module keeps `prefix:component`; multiple specifics collapse to just `prefix` + - Additional behavior: applies contributor tiers on PRs by merged PR count (`trusted` >=5, `experienced` >=10, `principal` >=20, `distinguished` >=50) + - Additional behavior: final label set is priority-sorted (`risk:*` first, then `size:*`, then contributor tier, then module/path labels) + - Additional behavior: managed label colors follow display order to produce a smooth left-to-right gradient when many labels are present + - Manual governance: supports `workflow_dispatch` with `mode=audit|repair` to inspect/fix managed label metadata drift across the whole repository + - Additional behavior: risk + size labels are auto-corrected on manual PR label edits (`labeled`/`unlabeled` events); apply `risk: manual` when maintainers intentionally override automated risk selection + - High-risk heuristic paths: `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**` + - Guardrail: maintainers can apply `risk: manual` to freeze automated risk recalculation +- `.github/workflows/auto-response.yml` (`PR Auto Responder`) + - Purpose: first-time contributor onboarding + label-driven response routing (`r:support`, `r:needs-repro`, etc.) + - Additional behavior: applies contributor tiers on issues by merged PR count (`trusted` >=5, `experienced` >=10, `principal` >=20, `distinguished` >=50), matching PR tier thresholds exactly + - Additional behavior: contributor-tier labels are treated as automation-managed (manual add/remove on PR/issue is auto-corrected) + - Guardrail: label-based close routes are issue-only; PRs are never auto-closed by route labels +- `.github/workflows/stale.yml` (`Stale`) + - Purpose: stale issue/PR lifecycle automation +- `.github/dependabot.yml` (`Dependabot`) + - Purpose: grouped, rate-limited dependency update PRs (Cargo + GitHub Actions) +- `.github/workflows/pr-hygiene.yml` (`PR Hygiene`) + - Purpose: nudge stale-but-active PRs to rebase/re-run required checks before queue starvation + +## Trigger Map + +- `CI`: push to `main`, PRs to `main` +- `Docker`: push to `main`, tag push (`v*`), PRs touching docker/workflow files, manual dispatch +- `Release`: tag push (`v*`) +- `Security Audit`: push to `main`, PRs to `main`, weekly schedule +- `Workflow Sanity`: PR/push when `.github/workflows/**`, `.github/*.yml`, or `.github/*.yaml` change +- `PR Intake Sanity`: `pull_request_target` on opened/reopened/synchronize/edited/ready_for_review +- `Label Policy Sanity`: PR/push when `.github/label-policy.json`, `.github/workflows/labeler.yml`, or `.github/workflows/auto-response.yml` changes +- `PR Labeler`: `pull_request_target` lifecycle events +- `PR Auto Responder`: issue opened/labeled, `pull_request_target` opened/labeled +- `Stale`: daily schedule, manual dispatch +- `Dependabot`: weekly dependency maintenance windows +- `PR Hygiene`: every 12 hours schedule, manual dispatch + +## Fast Triage Guide + +1. `CI Required Gate` failing: start with `.github/workflows/ci.yml`. +2. Docker failures on PRs: inspect `.github/workflows/docker.yml` `pr-smoke` job. +3. Release failures on tags: inspect `.github/workflows/release.yml`. +4. Security failures: inspect `.github/workflows/security.yml` and `deny.toml`. +5. Workflow syntax/lint failures: inspect `.github/workflows/workflow-sanity.yml`. +6. PR intake failures: inspect `.github/workflows/pr-intake-sanity.yml` sticky comment and run logs. +7. Label policy parity failures: inspect `.github/workflows/label-policy-sanity.yml`. +8. Docs failures in CI: inspect `docs-quality` job logs in `.github/workflows/ci.yml`. +9. Strict delta lint failures in CI: inspect `lint-strict-delta` job logs and compare with `BASE_SHA` diff scope. + +## Maintenance Rules + +- Keep merge-blocking checks deterministic and reproducible (`--locked` where applicable). +- Keep merge-blocking rust quality policy aligned across `.github/workflows/ci.yml`, `dev/ci.sh`, and `.githooks/pre-push` (`./scripts/ci/rust_quality_gate.sh` + `./scripts/ci/rust_strict_delta_gate.sh`). +- Use `./scripts/ci/rust_strict_delta_gate.sh` (or `./dev/ci.sh lint-delta`) as the incremental strict merge gate for changed Rust lines. +- Run full strict lint audits regularly via `./scripts/ci/rust_quality_gate.sh --strict` (for example through `./dev/ci.sh lint-strict`) and track cleanup in focused PRs. +- Keep docs markdown gating incremental via `./scripts/ci/docs_quality_gate.sh` (block changed-line issues, report baseline issues separately). +- Keep docs link gating incremental via `./scripts/ci/collect_changed_links.py` + lychee (check only links added on changed lines). +- Prefer explicit workflow permissions (least privilege). +- Keep Actions source policy restricted to approved allowlist patterns (see `docs/actions-source-policy.md`). +- Use path filters for expensive workflows when practical. +- Keep docs quality checks low-noise (incremental markdown + incremental added-link checks). +- Keep dependency update volume controlled (grouping + PR limits). +- Avoid mixing onboarding/community automation with merge-gating logic. + +## Automation Side-Effect Controls + +- Prefer deterministic automation that can be manually overridden (`risk: manual`) when context is nuanced. +- Keep auto-response comments deduplicated to prevent triage noise. +- Keep auto-close behavior scoped to issues; maintainers own PR close/merge decisions. +- If automation is wrong, correct labels first, then continue review with explicit rationale. +- Use `superseded` / `stale-candidate` labels to prune duplicate or dormant PRs before deep review. diff --git a/docs/datasheets/arduino-uno.md b/docs/datasheets/arduino-uno.md new file mode 100644 index 0000000..be4d4fc --- /dev/null +++ b/docs/datasheets/arduino-uno.md @@ -0,0 +1,37 @@ +# Arduino Uno + +## Pin Aliases + +| alias | pin | +|-------------|-----| +| red_led | 13 | +| builtin_led | 13 | +| user_led | 13 | + +## Overview + +Arduino Uno is a microcontroller board based on the ATmega328P. It has 14 digital I/O pins (0–13) and 6 analog inputs (A0–A5). + +## Digital Pins + +- **Pins 0–13:** Digital I/O. Can be INPUT or OUTPUT. +- **Pin 13:** Built-in LED (onboard). Connect LED to GND or use for output. +- **Pins 0–1:** Also used for Serial (RX/TX). Avoid if using Serial. + +## GPIO + +- `digitalWrite(pin, HIGH)` or `digitalWrite(pin, LOW)` for output. +- `digitalRead(pin)` for input (returns 0 or 1). +- Pin numbers in ZeroClaw protocol: 0–13. + +## Serial + +- UART on pins 0 (RX) and 1 (TX). +- USB via ATmega16U2 or CH340 (clones). +- Baud rate: 115200 for ZeroClaw firmware. + +## ZeroClaw Tools + +- `gpio_read`: Read pin value (0 or 1). +- `gpio_write`: Set pin high (1) or low (0). +- `arduino_upload`: Agent generates full Arduino sketch code; ZeroClaw compiles and uploads it via arduino-cli. Use for "make a heart", custom patterns — agent writes the code, no manual editing. Pin 13 = built-in LED. diff --git a/docs/datasheets/esp32.md b/docs/datasheets/esp32.md new file mode 100644 index 0000000..8cb453d --- /dev/null +++ b/docs/datasheets/esp32.md @@ -0,0 +1,22 @@ +# ESP32 GPIO Reference + +## Pin Aliases + +| alias | pin | +|-------------|-----| +| builtin_led | 2 | +| red_led | 2 | + +## Common pins (ESP32 / ESP32-C3) + +- **GPIO 2**: Built-in LED on many dev boards (output) +- **GPIO 13**: General-purpose output +- **GPIO 21/20**: Often used for UART0 TX/RX (avoid if using serial) + +## Protocol + +ZeroClaw host sends JSON over serial (115200 baud): +- `gpio_read`: `{"id":"1","cmd":"gpio_read","args":{"pin":13}}` +- `gpio_write`: `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}` + +Response: `{"id":"1","ok":true,"result":"0"}` or `{"id":"1","ok":true,"result":"done"}` diff --git a/docs/datasheets/nucleo-f401re.md b/docs/datasheets/nucleo-f401re.md new file mode 100644 index 0000000..22b1e93 --- /dev/null +++ b/docs/datasheets/nucleo-f401re.md @@ -0,0 +1,16 @@ +# Nucleo-F401RE GPIO + +## Pin Aliases + +| alias | pin | +|-------------|-----| +| red_led | 13 | +| user_led | 13 | +| ld2 | 13 | +| builtin_led | 13 | + +## GPIO + +Pin 13: User LED (LD2) +- Output, active high +- PA5 on STM32F401 diff --git a/docs/frictionless-security.md b/docs/frictionless-security.md new file mode 100644 index 0000000..d23dbfc --- /dev/null +++ b/docs/frictionless-security.md @@ -0,0 +1,312 @@ +# Frictionless Security: Zero Impact on Wizard + +## Core Principle +> **"Security features should be like airbags — present, protective, and invisible until needed."** + +## Design: Silent Auto-Detection + +### 1. No New Wizard Steps (Stays 9 Steps, < 60 Seconds) + +```rust +// Wizard remains UNCHANGED +// Security features auto-detect in background + +pub fn run_wizard() -> Result { + // ... existing 9 steps, no changes ... + + let config = Config { + // ... existing fields ... + + // NEW: Auto-detected security (not shown in wizard) + security: SecurityConfig::autodetect(), // Silent! + }; + + config.save()?; + Ok(config) +} +``` + +### 2. Auto-Detection Logic (Runs Once at First Start) + +```rust +// src/security/detect.rs + +impl SecurityConfig { + /// Detect available sandboxing and enable automatically + /// Returns smart defaults based on platform + available tools + pub fn autodetect() -> Self { + Self { + // Sandbox: prefer Landlock (native), then Firejail, then none + sandbox: SandboxConfig::autodetect(), + + // Resource limits: always enable monitoring + resources: ResourceLimits::default(), + + // Audit: enable by default, log to config dir + audit: AuditConfig::default(), + + // Everything else: safe defaults + ..SecurityConfig::default() + } + } +} + +impl SandboxConfig { + pub fn autodetect() -> Self { + #[cfg(target_os = "linux")] + { + // Prefer Landlock (native, no dependency) + if Self::probe_landlock() { + return Self { + enabled: true, + backend: SandboxBackend::Landlock, + ..Self::default() + }; + } + + // Fallback: Firejail if installed + if Self::probe_firejail() { + return Self { + enabled: true, + backend: SandboxBackend::Firejail, + ..Self::default() + }; + } + } + + #[cfg(target_os = "macos")] + { + // Try Bubblewrap on macOS + if Self::probe_bubblewrap() { + return Self { + enabled: true, + backend: SandboxBackend::Bubblewrap, + ..Self::default() + }; + } + } + + // Fallback: disabled (but still has application-layer security) + Self { + enabled: false, + backend: SandboxBackend::None, + ..Self::default() + } + } + + #[cfg(target_os = "linux")] + fn probe_landlock() -> bool { + // Try creating a minimal Landlock ruleset + // If it works, kernel supports Landlock + landlock::Ruleset::new() + .set_access_fs(landlock::AccessFS::read_file) + .add_path(Path::new("/tmp"), landlock::AccessFS::read_file) + .map(|ruleset| ruleset.restrict_self().is_ok()) + .unwrap_or(false) + } + + fn probe_firejail() -> bool { + // Check if firejail command exists + std::process::Command::new("firejail") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } +} +``` + +### 3. First Run: Silent Logging + +```bash +$ zeroclaw agent -m "hello" + +# First time: silent detection +[INFO] Detecting security features... +[INFO] ✓ Landlock sandbox enabled (kernel 6.2+) +[INFO] ✓ Memory monitoring active (512MB limit) +[INFO] ✓ Audit logging enabled (~/.config/zeroclaw/audit.log) + +# Subsequent runs: quiet +$ zeroclaw agent -m "hello" +[agent] Thinking... +``` + +### 4. Config File: All Defaults Hidden + +```toml +# ~/.config/zeroclaw/config.toml + +# These sections are NOT written unless user customizes +# [security.sandbox] +# enabled = true # (default, auto-detected) +# backend = "landlock" # (default, auto-detected) + +# [security.resources] +# max_memory_mb = 512 # (default) + +# [security.audit] +# enabled = true # (default) +``` + +Only when user changes something: +```toml +[security.sandbox] +enabled = false # User explicitly disabled + +[security.resources] +max_memory_mb = 1024 # User increased limit +``` + +### 5. Advanced Users: Explicit Control + +```bash +# Check what's active +$ zeroclaw security --status +Security Status: + ✓ Sandbox: Landlock (Linux kernel 6.2) + ✓ Memory monitoring: 512MB limit + ✓ Audit logging: ~/.config/zeroclaw/audit.log + → 47 events logged today + +# Disable sandbox explicitly (writes to config) +$ zeroclaw config set security.sandbox.enabled false + +# Enable specific backend +$ zeroclaw config set security.sandbox.backend firejail + +# Adjust limits +$ zeroclaw config set security.resources.max_memory_mb 2048 +``` + +### 6. Graceful Degradation + +| Platform | Best Available | Fallback | Worst Case | +|----------|---------------|----------|------------| +| **Linux 5.13+** | Landlock | None | App-layer only | +| **Linux (any)** | Firejail | Landlock | App-layer only | +| **macOS** | Bubblewrap | None | App-layer only | +| **Windows** | None | - | App-layer only | + +**App-layer security is always present** — this is the existing allowlist/path blocking/injection protection that's already comprehensive. + +--- + +## Config Schema Extension + +```rust +// src/config/schema.rs + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityConfig { + /// Sandbox configuration (auto-detected if not set) + #[serde(default)] + pub sandbox: SandboxConfig, + + /// Resource limits (defaults applied if not set) + #[serde(default)] + pub resources: ResourceLimits, + + /// Audit logging (enabled by default) + #[serde(default)] + pub audit: AuditConfig, +} + +impl Default for SecurityConfig { + fn default() -> Self { + Self { + sandbox: SandboxConfig::autodetect(), // Silent detection! + resources: ResourceLimits::default(), + audit: AuditConfig::default(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SandboxConfig { + /// Enable sandboxing (default: auto-detected) + #[serde(default)] + pub enabled: Option, // None = auto-detect + + /// Sandbox backend (default: auto-detect) + #[serde(default)] + pub backend: SandboxBackend, + + /// Custom Firejail args (optional) + #[serde(default)] + pub firejail_args: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SandboxBackend { + Auto, // Auto-detect (default) + Landlock, // Linux kernel LSM + Firejail, // User-space sandbox + Bubblewrap, // User namespaces + Docker, // Container (heavy) + None, // Disabled +} + +impl Default for SandboxBackend { + fn default() -> Self { + Self::Auto // Always auto-detect by default + } +} +``` + +--- + +## User Experience Comparison + +### Before (Current) +```bash +$ zeroclaw onboard +[1/9] Workspace Setup... +[2/9] AI Provider... +... +[9/9] Workspace Files... +✓ Security: Supervised | workspace-scoped +``` + +### After (With Frictionless Security) +```bash +$ zeroclaw onboard +[1/9] Workspace Setup... +[2/9] AI Provider... +... +[9/9] Workspace Files... +✓ Security: Supervised | workspace-scoped | Landlock sandbox ✓ +# ↑ Just one extra word, silent auto-detection! +``` + +### Advanced User (Explicit Control) +```bash +$ zeroclaw onboard --security-level paranoid +[1/9] Workspace Setup... +... +✓ Security: Paranoid | Landlock + Firejail | Audit signed +``` + +--- + +## Backward Compatibility + +| Scenario | Behavior | +|----------|----------| +| **Existing config** | Works unchanged, new features opt-in | +| **New install** | Auto-detects and enables available security | +| **No sandbox available** | Falls back to app-layer (still secure) | +| **User disables** | One config flag: `sandbox.enabled = false` | + +--- + +## Summary + +✅ **Zero impact on wizard** — stays 9 steps, < 60 seconds +✅ **Zero new prompts** — silent auto-detection +✅ **Zero breaking changes** — backward compatible +✅ **Opt-out available** — explicit config flags +✅ **Status visibility** — `zeroclaw security --status` + +The wizard remains "quick setup universal applications" — security is just **quietly better**. diff --git a/docs/hardware-peripherals-design.md b/docs/hardware-peripherals-design.md new file mode 100644 index 0000000..87f61bf --- /dev/null +++ b/docs/hardware-peripherals-design.md @@ -0,0 +1,324 @@ +# Hardware Peripherals Design — ZeroClaw + +ZeroClaw enables microcontrollers (MCUs) and Single Board Computers (SBCs) to **dynamically interpret natural language commands**, generate hardware-specific code, and execute peripheral interactions in real-time. + +## 1. Vision + +**Goal:** ZeroClaw acts as a hardware-aware AI agent that: +- Receives natural language triggers (e.g. "Move X arm", "Turn on LED") via channels (WhatsApp, Telegram) +- Fetches accurate hardware documentation (datasheets, register maps) +- Synthesizes Rust code/logic using an LLM (Gemini, local open-source models) +- Executes the logic to manipulate peripherals (GPIO, I2C, SPI) +- Persists optimized code for future reuse + +**Mental model:** ZeroClaw = brain that understands hardware. Peripherals = arms and legs it controls. + +## 2. Two Modes of Operation + +### Mode 1: Edge-Native (Standalone) + +**Target:** Wi-Fi-enabled boards (ESP32, Raspberry Pi). + +ZeroClaw runs **directly on the device**. The board spins up a gRPC/nanoRPC server and communicates with peripherals locally. + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ ZeroClaw on ESP32 / Raspberry Pi (Edge-Native) │ +│ │ +│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────────────────┐ │ +│ │ Channels │───►│ Agent Loop │───►│ RAG: datasheets, register maps │ │ +│ │ WhatsApp │ │ (LLM calls) │ │ → LLM context │ │ +│ │ Telegram │ └──────┬───────┘ └─────────────────────────────────┘ │ +│ └─────────────┘ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────────┐│ +│ │ Code synthesis → Wasm / dynamic exec → GPIO / I2C / SPI → persist ││ +│ └─────────────────────────────────────────────────────────────────────────┘│ +│ │ +│ gRPC/nanoRPC server ◄──► Peripherals (GPIO, I2C, SPI, sensors, actuators) │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +**Workflow:** +1. User sends WhatsApp: *"Turn on LED on pin 13"* +2. ZeroClaw fetches board-specific docs (e.g. ESP32 GPIO mapping) +3. LLM synthesizes Rust code +4. Code runs in a sandbox (Wasm or dynamic linking) +5. GPIO is toggled; result returned to user +6. Optimized code is persisted for future "Turn on LED" requests + +**All happens on-device.** No host required. + +### Mode 2: Host-Mediated (Development / Debugging) + +**Target:** Hardware connected via USB / J-Link / Aardvark to a host (macOS, Linux). + +ZeroClaw runs on the **host** and maintains a hardware-aware link to the target. Used for development, introspection, and flashing. + +``` +┌─────────────────────┐ ┌──────────────────────────────────┐ +│ ZeroClaw on Mac │ USB / J-Link / │ STM32 Nucleo-F401RE │ +│ │ Aardvark │ (or other MCU) │ +│ - Channels │ ◄────────────────► │ - Memory map │ +│ - LLM │ │ - Peripherals (GPIO, ADC, I2C) │ +│ - Hardware probe │ VID/PID │ - Flash / RAM │ +│ - Flash / debug │ discovery │ │ +└─────────────────────┘ └──────────────────────────────────┘ +``` + +**Workflow:** +1. User sends Telegram: *"What are the readable memory addresses on this USB device?"* +2. ZeroClaw identifies connected hardware (VID/PID, architecture) +3. Performs memory mapping; suggests available address spaces +4. Returns result to user + +**Or:** +1. User: *"Flash this firmware to the Nucleo"* +2. ZeroClaw writes/flashes via OpenOCD or probe-rs +3. Confirms success + +**Or:** +1. ZeroClaw auto-discovers: *"STM32 Nucleo on /dev/ttyACM0, ARM Cortex-M4"* +2. Suggests: *"I can read/write GPIO, ADC, flash. What would you like to do?"* + +--- + +### Mode Comparison + +| Aspect | Edge-Native | Host-Mediated | +|------------------|--------------------------------|----------------------------------| +| ZeroClaw runs on | Device (ESP32, RPi) | Host (Mac, Linux) | +| Hardware link | Local (GPIO, I2C, SPI) | USB, J-Link, Aardvark | +| LLM | On-device or cloud (Gemini) | Host (cloud or local) | +| Use case | Production, standalone | Dev, debug, introspection | +| Channels | WhatsApp, etc. (via WiFi) | Telegram, CLI, etc. | + +## 3. Legacy / Simpler Modes (Pre-LLM-on-Edge) + +For boards without WiFi or before full Edge-Native is ready: + +### Mode A: Host + Remote Peripheral (STM32 via serial) + +Host runs ZeroClaw; peripheral runs minimal firmware. Simple JSON over serial. + +### Mode B: RPi as Host (Native GPIO) + +ZeroClaw on Pi; GPIO via rppal or sysfs. No separate firmware. + +## 4. Technical Requirements + +| Requirement | Description | +|-------------|-------------| +| **Language** | Pure Rust. `no_std` where applicable for embedded targets (STM32, ESP32). | +| **Communication** | Lightweight gRPC or nanoRPC stack for low-latency command processing. | +| **Dynamic execution** | Safely run LLM-generated logic on-the-fly: Wasm runtime for isolation, or dynamic linking where supported. | +| **Documentation retrieval** | RAG (Retrieval-Augmented Generation) pipeline to feed datasheet snippets, register maps, and pinouts into LLM context. | +| **Hardware discovery** | VID/PID-based identification for USB devices; architecture detection (ARM Cortex-M, RISC-V, etc.). | + +### RAG Pipeline (Datasheet Retrieval) + +- **Index:** Datasheets, reference manuals, register maps (PDF → chunks, embeddings). +- **Retrieve:** On user query ("turn on LED"), fetch relevant snippets (e.g. GPIO section for target board). +- **Inject:** Add to LLM system prompt or context. +- **Result:** LLM generates accurate, board-specific code. + +### Dynamic Execution Options + +| Option | Pros | Cons | +|-------|------|------| +| **Wasm** | Sandboxed, portable, no FFI | Overhead; limited HW access from Wasm | +| **Dynamic linking** | Native speed, full HW access | Platform-specific; security concerns | +| **Interpreted DSL** | Safe, auditable | Slower; limited expressiveness | +| **Pre-compiled templates** | Fast, secure | Less flexible; requires template library | + +**Recommendation:** Start with pre-compiled templates + parameterization; evolve to Wasm for user-defined logic once stable. + +## 5. CLI and Config + +### CLI Flags + +```bash +# Edge-Native: run on device (ESP32, RPi) +zeroclaw agent --mode edge + +# Host-Mediated: connect to USB/J-Link target +zeroclaw agent --peripheral nucleo-f401re:/dev/ttyACM0 +zeroclaw agent --probe jlink + +# Hardware introspection +zeroclaw hardware discover +zeroclaw hardware introspect /dev/ttyACM0 +``` + +### Config (config.toml) + +```toml +[peripherals] +enabled = true +mode = "host" # "edge" | "host" +datasheet_dir = "docs/datasheets" # RAG: board-specific docs for LLM context + +[[peripherals.boards]] +board = "nucleo-f401re" +transport = "serial" +path = "/dev/ttyACM0" +baud = 115200 + +[[peripherals.boards]] +board = "rpi-gpio" +transport = "native" + +[[peripherals.boards]] +board = "esp32" +transport = "wifi" +# Edge-Native: ZeroClaw runs on ESP32 +``` + +## 6. Architecture: Peripheral as Extension Point + +### New Trait: `Peripheral` + +```rust +/// A hardware peripheral that exposes capabilities as tools. +#[async_trait] +pub trait Peripheral: Send + Sync { + fn name(&self) -> &str; + fn board_type(&self) -> &str; // e.g. "nucleo-f401re", "rpi-gpio" + async fn connect(&mut self) -> anyhow::Result<()>; + async fn disconnect(&mut self) -> anyhow::Result<()>; + async fn health_check(&self) -> bool; + /// Tools this peripheral provides (gpio_read, gpio_write, sensor_read, etc.) + fn tools(&self) -> Vec>; +} +``` + +### Flow + +1. **Startup:** ZeroClaw loads config, sees `peripherals.boards`. +2. **Connect:** For each board, create a `Peripheral` impl, call `connect()`. +3. **Tools:** Collect tools from all connected peripherals; merge with default tools. +4. **Agent loop:** Agent can call `gpio_write`, `sensor_read`, etc. — these delegate to the peripheral. +5. **Shutdown:** Call `disconnect()` on each peripheral. + +### Board Support + +| Board | Transport | Firmware / Driver | Tools | +|--------------------|-----------|------------------------|--------------------------| +| nucleo-f401re | serial | Zephyr / Embassy | gpio_read, gpio_write, adc_read | +| rpi-gpio | native | rppal or sysfs | gpio_read, gpio_write | +| esp32 | serial/ws | ESP-IDF / Embassy | gpio, wifi, mqtt | + +## 7. Communication Protocols + +### gRPC / nanoRPC (Edge-Native, Host-Mediated) + +For low-latency, typed RPC between ZeroClaw and peripherals: + +- **nanoRPC** or **tonic** (gRPC): Protobuf-defined services. +- Methods: `GpioWrite`, `GpioRead`, `I2cTransfer`, `SpiTransfer`, `MemoryRead`, `FlashWrite`, etc. +- Enables streaming, bidirectional calls, and code generation from `.proto` files. + +### Serial Fallback (Host-Mediated, legacy) + +Simple JSON over serial for boards without gRPC support: + +**Request (host → peripheral):** +```json +{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}} +``` + +**Response (peripheral → host):** +```json +{"id":"1","ok":true,"result":"done"} +``` + +## 8. Firmware (Separate Repo or Crate) + +- **zeroclaw-firmware** or **zeroclaw-peripheral** — a separate crate/workspace. +- Targets: `thumbv7em-none-eabihf` (STM32), `armv7-unknown-linux-gnueabihf` (RPi), etc. +- Uses `embassy` or Zephyr for STM32. +- Implements the protocol above. +- User flashes this to the board; ZeroClaw connects and discovers capabilities. + +## 9. Implementation Phases + +### Phase 1: Skeleton ✅ (Done) + +- [x] Add `Peripheral` trait, config schema, CLI (`zeroclaw peripheral list/add`) +- [x] Add `--peripheral` flag to agent +- [x] Document in AGENTS.md + +### Phase 2: Host-Mediated — Hardware Discovery ✅ (Done) + +- [x] `zeroclaw hardware discover`: enumerate USB devices (VID/PID) +- [x] Board registry: map VID/PID → architecture, name (e.g. Nucleo-F401RE) +- [x] `zeroclaw hardware introspect `: memory map, peripheral list + +### Phase 3: Host-Mediated — Serial / J-Link + +- [x] `SerialPeripheral` for STM32 over USB CDC +- [ ] probe-rs or OpenOCD integration for flash/debug +- [x] Tools: `gpio_read`, `gpio_write` (memory_read, flash_write in future) + +### Phase 4: RAG Pipeline ✅ (Done) + +- [x] Datasheet index (markdown/text → chunks) +- [x] Retrieve-and-inject into LLM context on hardware-related queries +- [x] Board-specific prompt augmentation + +**Usage:** Add `datasheet_dir = "docs/datasheets"` to `[peripherals]` in config.toml. Place `.md` or `.txt` files named by board (e.g. `nucleo-f401re.md`, `rpi-gpio.md`). Files in `_generic/` or named `generic.md` apply to all boards. Chunks are retrieved by keyword match and injected into the user message context. + +### Phase 5: Edge-Native — RPi ✅ (Done) + +- [x] ZeroClaw on Raspberry Pi (native GPIO via rppal) +- [ ] gRPC/nanoRPC server for local peripheral access +- [ ] Code persistence (store synthesized snippets) + +### Phase 6: Edge-Native — ESP32 + +- [x] Host-mediated ESP32 (serial transport) — same JSON protocol as STM32 +- [x] `zeroclaw-esp32` firmware crate (`firmware/zeroclaw-esp32`) — GPIO over UART +- [x] ESP32 in hardware registry (CH340 VID/PID) +- [ ] ZeroClaw *on* ESP32 (WiFi + LLM, edge-native) — future +- [ ] Wasm or template-based execution for LLM-generated logic + +**Usage:** Flash `firmware/zeroclaw-esp32` to ESP32, add `board = "esp32"`, `transport = "serial"`, `path = "/dev/ttyUSB0"` to config. + +### Phase 7: Dynamic Execution (LLM-Generated Code) + +- [ ] Template library: parameterized GPIO/I2C/SPI snippets +- [ ] Optional: Wasm runtime for user-defined logic (sandboxed) +- [ ] Persist and reuse optimized code paths + +## 10. Security Considerations + +- **Serial path:** Validate `path` is in allowlist (e.g. `/dev/ttyACM*`, `/dev/ttyUSB*`); never arbitrary paths. +- **GPIO:** Restrict which pins are exposed; avoid power/reset pins. +- **No secrets on peripheral:** Firmware should not store API keys; host handles auth. + +## 11. Non-Goals (For Now) + +- Running full ZeroClaw *on* bare STM32 (no WiFi, limited RAM) — use Host-Mediated instead +- Real-time guarantees — peripherals are best-effort +- Arbitrary native code execution from LLM — prefer Wasm or templates + +## 12. Related Documents + +- [adding-boards-and-tools.md](./adding-boards-and-tools.md) — How to add boards and datasheets +- [network-deployment.md](./network-deployment.md) — RPi and network deployment + +## 13. References + +- [Zephyr RTOS Rust support](https://docs.zephyrproject.org/latest/develop/languages/rust/index.html) +- [Embassy](https://embassy.dev/) — async embedded framework +- [rppal](https://github.com/golemparts/rppal) — Raspberry Pi GPIO in Rust +- [STM32 Nucleo-F401RE](https://www.st.com/en/evaluation-tools/nucleo-f401re.html) +- [tonic](https://github.com/hyperium/tonic) — gRPC for Rust +- [probe-rs](https://probe.rs/) — ARM debug probe, flash, memory access +- [nusb](https://github.com/nic-hartley/nusb) — USB device enumeration (VID/PID) + +## 14. Raw Prompt Summary + +> *"Boards like ESP, Raspberry Pi, or boards with WiFi can connect to an LLM (Gemini or open-source). ZeroClaw runs on the device, creates its own gRPC, spins it up, and communicates with peripherals. User asks via WhatsApp: 'move X arm' or 'turn on LED'. ZeroClaw gets accurate documentation, writes code, executes it, stores it optimally, runs it, and turns on the LED — all on the development board.* +> +> *For STM Nucleo connected via USB/J-Link/Aardvark to my Mac: ZeroClaw from my Mac accesses the hardware, installs or writes what it wants on the device, and returns the result. Example: 'Hey ZeroClaw, what are the available/readable addresses on this USB device?' It can figure out what's connected where and suggest."* diff --git a/docs/langgraph-integration.md b/docs/langgraph-integration.md new file mode 100644 index 0000000..a7e64f9 --- /dev/null +++ b/docs/langgraph-integration.md @@ -0,0 +1,239 @@ +# LangGraph Integration Guide + +This guide explains how to use the `zeroclaw-tools` Python package for consistent tool calling with any OpenAI-compatible LLM provider. + +## Background + +Some LLM providers, particularly Chinese models like GLM-5 (Zhipu AI), have inconsistent tool calling behavior when using text-based tool invocation. ZeroClaw's Rust core uses structured tool calling via the OpenAI API format, but some models respond better to a different approach. + +LangGraph provides a stateful graph execution engine that guarantees consistent tool calling behavior regardless of the underlying model's native capabilities. + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Your Application │ +├─────────────────────────────────────────────────────────────┤ +│ zeroclaw-tools Agent │ +│ │ +│ ┌─────────────────────────────────────────────────────┐ │ +│ │ LangGraph StateGraph │ │ +│ │ │ │ +│ │ ┌────────────┐ ┌────────────┐ │ │ +│ │ │ Agent │ ──────▶ │ Tools │ │ │ +│ │ │ Node │ ◀────── │ Node │ │ │ +│ │ └────────────┘ └────────────┘ │ │ +│ │ │ │ │ │ +│ │ ▼ ▼ │ │ +│ │ [Continue?] [Execute Tool] │ │ +│ │ │ │ │ │ +│ │ Yes │ No Result│ │ │ +│ │ ▼ ▼ │ │ +│ │ [END] [Back to Agent] │ │ +│ │ │ │ +│ └─────────────────────────────────────────────────────┘ │ +│ │ +├─────────────────────────────────────────────────────────────┤ +│ OpenAI-Compatible LLM Provider │ +│ (Z.AI, OpenRouter, Groq, DeepSeek, Ollama, etc.) │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Quick Start + +### Installation + +```bash +pip install zeroclaw-tools +``` + +### Basic Usage + +```python +import asyncio +from zeroclaw_tools import create_agent, shell, file_read, file_write +from langchain_core.messages import HumanMessage + +async def main(): + agent = create_agent( + tools=[shell, file_read, file_write], + model="glm-5", + api_key="your-api-key", + base_url="https://api.z.ai/api/coding/paas/v4" + ) + + result = await agent.ainvoke({ + "messages": [HumanMessage(content="Read /etc/hostname and tell me the machine name")] + }) + + print(result["messages"][-1].content) + +asyncio.run(main()) +``` + +## Available Tools + +### Core Tools + +| Tool | Description | +|------|-------------| +| `shell` | Execute shell commands | +| `file_read` | Read file contents | +| `file_write` | Write content to files | + +### Extended Tools + +| Tool | Description | +|------|-------------| +| `web_search` | Search the web (requires `BRAVE_API_KEY`) | +| `http_request` | Make HTTP requests | +| `memory_store` | Store data in persistent memory | +| `memory_recall` | Recall stored data | + +## Custom Tools + +Create your own tools with the `@tool` decorator: + +```python +from zeroclaw_tools import tool, create_agent + +@tool +def get_weather(city: str) -> str: + """Get the current weather for a city.""" + # Your implementation + return f"Weather in {city}: Sunny, 25°C" + +@tool +def query_database(sql: str) -> str: + """Execute a SQL query and return results.""" + # Your implementation + return "Query returned 5 rows" + +agent = create_agent( + tools=[get_weather, query_database], + model="glm-5", + api_key="your-key" +) +``` + +## Provider Configuration + +### Z.AI / GLM-5 + +```python +agent = create_agent( + model="glm-5", + api_key="your-zhipu-key", + base_url="https://api.z.ai/api/coding/paas/v4" +) +``` + +### OpenRouter + +```python +agent = create_agent( + model="anthropic/claude-3.5-sonnet", + api_key="your-openrouter-key", + base_url="https://openrouter.ai/api/v1" +) +``` + +### Groq + +```python +agent = create_agent( + model="llama-3.3-70b-versatile", + api_key="your-groq-key", + base_url="https://api.groq.com/openai/v1" +) +``` + +### Ollama (Local) + +```python +agent = create_agent( + model="llama3.2", + base_url="http://localhost:11434/v1" +) +``` + +## Discord Bot Integration + +```python +import os +from zeroclaw_tools.integrations import DiscordBot + +bot = DiscordBot( + token=os.environ["DISCORD_TOKEN"], + guild_id=123456789, # Your Discord server ID + allowed_users=["123456789"], # User IDs that can use the bot + api_key=os.environ["API_KEY"], + model="glm-5" +) + +bot.run() +``` + +## CLI Usage + +```bash +# Set environment variables +export API_KEY="your-key" +export BRAVE_API_KEY="your-brave-key" # Optional, for web search + +# Single message +zeroclaw-tools "What is the current date?" + +# Interactive mode +zeroclaw-tools -i +``` + +## Comparison with Rust ZeroClaw + +| Aspect | Rust ZeroClaw | zeroclaw-tools | +|--------|---------------|-----------------| +| **Performance** | Ultra-fast (~10ms startup) | Python startup (~500ms) | +| **Memory** | <5 MB | ~50 MB | +| **Binary size** | ~3.4 MB | pip package | +| **Tool consistency** | Model-dependent | LangGraph guarantees | +| **Extensibility** | Rust traits | Python decorators | +| **Ecosystem** | Rust crates | PyPI packages | + +**When to use Rust ZeroClaw:** +- Production edge deployments +- Resource-constrained environments (Raspberry Pi, etc.) +- Maximum performance requirements + +**When to use zeroclaw-tools:** +- Models with inconsistent native tool calling +- Python-centric development +- Rapid prototyping +- Integration with Python ML ecosystem + +## Troubleshooting + +### "API key required" error + +Set the `API_KEY` environment variable or pass `api_key` to `create_agent()`. + +### Tool calls not executing + +Ensure your model supports function calling. Some older models may not support tools. + +### Rate limiting + +Add delays between calls or implement your own rate limiting: + +```python +import asyncio + +for message in messages: + result = await agent.ainvoke({"messages": [message]}) + await asyncio.sleep(1) # Rate limit +``` + +## Related Projects + +- [rs-graph-llm](https://github.com/a-agmon/rs-graph-llm) - Rust LangGraph alternative +- [langchain-rust](https://github.com/Abraxas-365/langchain-rust) - LangChain for Rust +- [llm-chain](https://github.com/sobelio/llm-chain) - LLM chains in Rust diff --git a/docs/mattermost-setup.md b/docs/mattermost-setup.md new file mode 100644 index 0000000..6549880 --- /dev/null +++ b/docs/mattermost-setup.md @@ -0,0 +1,48 @@ +# Mattermost Integration Guide + +ZeroClaw supports native integration with Mattermost via its REST API v4. This integration is ideal for self-hosted, private, or air-gapped environments where sovereign communication is a requirement. + +## Prerequisites + +1. **Mattermost Server**: A running Mattermost instance (self-hosted or cloud). +2. **Bot Account**: + - Go to **Main Menu > Integrations > Bot Accounts**. + - Click **Add Bot Account**. + - Set a username (e.g., `zeroclaw-bot`). + - Enable **post:all** and **channel:read** permissions (or appropriate scopes). + - Save the **Access Token**. +3. **Channel ID**: + - Open the Mattermost channel you want the bot to monitor. + - Click the channel header and select **View Info**. + - Copy the **ID** (e.g., `7j8k9l...`). + +## Configuration + +Add the following to your `config.toml` under the `[channels]` section: + +```toml +[channels.mattermost] +url = "https://mm.your-domain.com" +bot_token = "your-bot-access-token" +channel_id = "your-channel-id" +allowed_users = ["user-id-1", "user-id-2"] +``` + +### Configuration Fields + +| Field | Description | +|---|---| +| `url` | The base URL of your Mattermost server. | +| `bot_token` | The Personal Access Token for the bot account. | +| `channel_id` | (Optional) The ID of the channel to listen to. Required for `listen` mode. | +| `allowed_users` | (Optional) A list of Mattermost User IDs permitted to interact with the bot. Use `["*"]` to allow everyone. | + +## Threaded Conversations + +ZeroClaw automatically supports Mattermost threads. +- If a user sends a message in a thread, ZeroClaw will reply within that same thread. +- If a user sends a top-level message, ZeroClaw will start a thread by replying to that post. + +## Security Note + +Mattermost integration is designed for **sovereign communication**. By hosting your own Mattermost server, your agent's communication history remains entirely within your own infrastructure, avoiding third-party cloud logging. diff --git a/docs/network-deployment.md b/docs/network-deployment.md new file mode 100644 index 0000000..54a7694 --- /dev/null +++ b/docs/network-deployment.md @@ -0,0 +1,203 @@ +# Network Deployment — ZeroClaw on Raspberry Pi and Local Network + +This document covers deploying ZeroClaw on a Raspberry Pi or other host on your local network, with Telegram and optional webhook channels. + +--- + +## 1. Overview + +| Mode | Inbound port needed? | Use case | +|------|----------------------|----------| +| **Telegram polling** | No | ZeroClaw polls Telegram API; works from anywhere | +| **Discord/Slack** | No | Same — outbound only | +| **Gateway webhook** | Yes | POST /webhook, WhatsApp, etc. need a public URL | +| **Gateway pairing** | Yes | If you pair clients via the gateway | + +**Key:** Telegram, Discord, and Slack use **long-polling** — ZeroClaw makes outbound requests. No port forwarding or public IP required. + +--- + +## 2. ZeroClaw on Raspberry Pi + +### 2.1 Prerequisites + +- Raspberry Pi (3/4/5) with Raspberry Pi OS +- USB peripherals (Arduino, Nucleo) if using serial transport +- Optional: `rppal` for native GPIO (`peripheral-rpi` feature) + +### 2.2 Install + +```bash +# Build for RPi (or cross-compile from host) +cargo build --release --features hardware + +# Or install via your preferred method +``` + +### 2.3 Config + +Edit `~/.zeroclaw/config.toml`: + +```toml +[peripherals] +enabled = true + +[[peripherals.boards]] +board = "rpi-gpio" +transport = "native" + +# Or Arduino over USB +[[peripherals.boards]] +board = "arduino-uno" +transport = "serial" +path = "/dev/ttyACM0" +baud = 115200 + +[channels_config.telegram] +bot_token = "YOUR_BOT_TOKEN" +allowed_users = [] + +[gateway] +host = "127.0.0.1" +port = 8080 +allow_public_bind = false +``` + +### 2.4 Run Daemon (Local Only) + +```bash +zeroclaw daemon --host 127.0.0.1 --port 8080 +``` + +- Gateway binds to `127.0.0.1` — not reachable from other machines +- Telegram channel works: ZeroClaw polls Telegram API (outbound) +- No firewall or port forwarding needed + +--- + +## 3. Binding to 0.0.0.0 (Local Network) + +To allow other devices on your LAN to hit the gateway (e.g. for pairing or webhooks): + +### 3.1 Option A: Explicit Opt-In + +```toml +[gateway] +host = "0.0.0.0" +port = 8080 +allow_public_bind = true +``` + +```bash +zeroclaw daemon --host 0.0.0.0 --port 8080 +``` + +**Security:** `allow_public_bind = true` exposes the gateway to your local network. Only use on trusted LANs. + +### 3.2 Option B: Tunnel (Recommended for Webhooks) + +If you need a **public URL** (e.g. WhatsApp webhook, external clients): + +1. Run gateway on localhost: + ```bash + zeroclaw daemon --host 127.0.0.1 --port 8080 + ``` + +2. Start a tunnel: + ```toml + [tunnel] + provider = "tailscale" # or "ngrok", "cloudflare" + ``` + Or use `zeroclaw tunnel` (see tunnel docs). + +3. ZeroClaw will refuse `0.0.0.0` unless `allow_public_bind = true` or a tunnel is active. + +--- + +## 4. Telegram Polling (No Inbound Port) + +Telegram uses **long-polling** by default: + +- ZeroClaw calls `https://api.telegram.org/bot{token}/getUpdates` +- No inbound port or public IP needed +- Works behind NAT, on RPi, in a home lab + +**Config:** + +```toml +[channels_config.telegram] +bot_token = "YOUR_BOT_TOKEN" +allowed_users = [] # deny-by-default, bind identities explicitly +``` + +Run `zeroclaw daemon` — Telegram channel starts automatically. + +To approve one Telegram account at runtime: + +```bash +zeroclaw channel bind-telegram +``` + +`` can be a numeric Telegram user ID or a username (without `@`). + +### 4.1 Single Poller Rule (Important) + +Telegram Bot API `getUpdates` supports only one active poller per bot token. + +- Keep one runtime instance for the same token (recommended: `zeroclaw daemon` service). +- Do not run `cargo run -- channel start` or another bot process at the same time. + +If you hit this error: + +`Conflict: terminated by other getUpdates request` + +you have a polling conflict. Stop extra instances and restart only one daemon. + +--- + +## 5. Webhook Channels (WhatsApp, Custom) + +Webhook-based channels need a **public URL** so Meta (WhatsApp) or your client can POST events. + +### 5.1 Tailscale Funnel + +```toml +[tunnel] +provider = "tailscale" +``` + +Tailscale Funnel exposes your gateway via a `*.ts.net` URL. No port forwarding. + +### 5.2 ngrok + +```toml +[tunnel] +provider = "ngrok" +``` + +Or run ngrok manually: +```bash +ngrok http 8080 +# Use the HTTPS URL for your webhook +``` + +### 5.3 Cloudflare Tunnel + +Configure Cloudflare Tunnel to forward to `127.0.0.1:8080`, then set your webhook URL to the tunnel's public hostname. + +--- + +## 6. Checklist: RPi Deployment + +- [ ] Build with `--features hardware` (and `peripheral-rpi` if using native GPIO) +- [ ] Configure `[peripherals]` and `[channels_config.telegram]` +- [ ] Run `zeroclaw daemon --host 127.0.0.1 --port 8080` (Telegram works without 0.0.0.0) +- [ ] For LAN access: `--host 0.0.0.0` + `allow_public_bind = true` in config +- [ ] For webhooks: use Tailscale, ngrok, or Cloudflare tunnel + +--- + +## 7. References + +- [hardware-peripherals-design.md](./hardware-peripherals-design.md) — Peripherals design +- [adding-boards-and-tools.md](./adding-boards-and-tools.md) — Hardware setup and adding boards diff --git a/docs/nucleo-setup.md b/docs/nucleo-setup.md new file mode 100644 index 0000000..76e942e --- /dev/null +++ b/docs/nucleo-setup.md @@ -0,0 +1,147 @@ +# ZeroClaw on Nucleo-F401RE — Step-by-Step Guide + +Run ZeroClaw on your Mac or Linux host. Connect a Nucleo-F401RE via USB. Control GPIO (LED, pins) via Telegram or CLI. + +--- + +## Get Board Info via Telegram (No Firmware Needed) + +ZeroClaw can read chip info from the Nucleo over USB **without flashing any firmware**. Message your Telegram bot: + +- *"What board info do I have?"* +- *"Board info"* +- *"What hardware is connected?"* +- *"Chip info"* + +The agent uses the `hardware_board_info` tool to return chip name, architecture, and memory map. With the `probe` feature, it reads live data via USB/SWD; otherwise it returns static datasheet info. + +**Config:** Add Nucleo to `config.toml` first (so the agent knows which board to query): + +```toml +[[peripherals.boards]] +board = "nucleo-f401re" +transport = "serial" +path = "/dev/ttyACM0" +baud = 115200 +``` + +**CLI alternative:** + +```bash +cargo build --features hardware,probe +zeroclaw hardware info +zeroclaw hardware discover +``` + +--- + +## What's Included (No Code Changes Needed) + +ZeroClaw includes everything for Nucleo-F401RE: + +| Component | Location | Purpose | +|-----------|----------|---------| +| Firmware | `firmware/zeroclaw-nucleo/` | Embassy Rust — USART2 (115200), gpio_read, gpio_write | +| Serial peripheral | `src/peripherals/serial.rs` | JSON-over-serial protocol (same as Arduino/ESP32) | +| Flash command | `zeroclaw peripheral flash-nucleo` | Builds firmware, flashes via probe-rs | + +Protocol: newline-delimited JSON. Request: `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}`. Response: `{"id":"1","ok":true,"result":"done"}`. + +--- + +## Prerequisites + +- Nucleo-F401RE board +- USB cable (USB-A to Mini-USB; Nucleo has built-in ST-Link) +- For flashing: `cargo install probe-rs-tools --locked` (or use the [install script](https://probe.rs/docs/getting-started/installation/)) + +--- + +## Phase 1: Flash Firmware + +### 1.1 Connect Nucleo + +1. Connect Nucleo to your Mac/Linux via USB. +2. The board appears as a USB device (ST-Link). No separate driver needed on modern systems. + +### 1.2 Flash via ZeroClaw + +From the zeroclaw repo root: + +```bash +zeroclaw peripheral flash-nucleo +``` + +This builds `firmware/zeroclaw-nucleo` and runs `probe-rs run --chip STM32F401RETx`. The firmware runs immediately after flashing. + +### 1.3 Manual Flash (Alternative) + +```bash +cd firmware/zeroclaw-nucleo +cargo build --release --target thumbv7em-none-eabihf +probe-rs run --chip STM32F401RETx target/thumbv7em-none-eabihf/release/zeroclaw-nucleo +``` + +--- + +## Phase 2: Find Serial Port + +- **macOS:** `/dev/cu.usbmodem*` or `/dev/tty.usbmodem*` (e.g. `/dev/cu.usbmodem101`) +- **Linux:** `/dev/ttyACM0` (or check `dmesg` after plugging in) + +USART2 (PA2/PA3) is bridged to the ST-Link's virtual COM port, so the host sees one serial device. + +--- + +## Phase 3: Configure ZeroClaw + +Add to `~/.zeroclaw/config.toml`: + +```toml +[peripherals] +enabled = true + +[[peripherals.boards]] +board = "nucleo-f401re" +transport = "serial" +path = "/dev/cu.usbmodem101" # adjust to your port +baud = 115200 +``` + +--- + +## Phase 4: Run and Test + +```bash +zeroclaw daemon --host 127.0.0.1 --port 8080 +``` + +Or use the agent directly: + +```bash +zeroclaw agent --message "Turn on the LED on pin 13" +``` + +Pin 13 = PA5 = User LED (LD2) on Nucleo-F401RE. + +--- + +## Summary: Commands + +| Step | Command | +|------|---------| +| 1 | Connect Nucleo via USB | +| 2 | `cargo install probe-rs --locked` | +| 3 | `zeroclaw peripheral flash-nucleo` | +| 4 | Add Nucleo to config.toml (path = your serial port) | +| 5 | `zeroclaw daemon` or `zeroclaw agent -m "Turn on LED"` | + +--- + +## Troubleshooting + +- **flash-nucleo unrecognized** — Build from repo: `cargo run --features hardware -- peripheral flash-nucleo`. The subcommand is only in the repo build, not in crates.io installs. +- **probe-rs not found** — `cargo install probe-rs-tools --locked` (the `probe-rs` crate is a library; the CLI is in `probe-rs-tools`) +- **No probe detected** — Ensure Nucleo is connected. Try another USB cable/port. +- **Serial port not found** — On Linux, add user to `dialout`: `sudo usermod -a -G dialout $USER`, then log out/in. +- **GPIO commands ignored** — Check `path` in config matches your serial port. Run `zeroclaw peripheral list` to verify. diff --git a/docs/pr-workflow.md b/docs/pr-workflow.md new file mode 100644 index 0000000..0afb9cd --- /dev/null +++ b/docs/pr-workflow.md @@ -0,0 +1,261 @@ +# ZeroClaw PR Workflow (High-Volume Collaboration) + +This document defines how ZeroClaw handles high PR volume while maintaining: + +- High performance +- High efficiency +- High stability +- High extensibility +- High sustainability +- High security + +Related references: + +- [`docs/ci-map.md`](ci-map.md) for per-workflow ownership, triggers, and triage flow. +- [`docs/reviewer-playbook.md`](reviewer-playbook.md) for day-to-day reviewer execution. + +## 1) Governance Goals + +1. Keep merge throughput predictable under heavy PR load. +2. Keep CI signal quality high (fast feedback, low false positives). +3. Keep security review explicit for risky surfaces. +4. Keep changes easy to reason about and easy to revert. +5. Keep repository artifacts free of personal/sensitive data leakage. + +### Governance Design Logic (Control Loop) + +This workflow is intentionally layered to reduce reviewer load while keeping accountability clear: + +1. **Intake classification**: path/size/risk/module labels route the PR to the right review depth. +2. **Deterministic validation**: merge gate depends on reproducible checks, not subjective comments. +3. **Risk-based review depth**: high-risk paths trigger deep review; low-risk paths stay fast. +4. **Rollback-first merge contract**: every merge path includes concrete recovery steps. + +Automation assists with triage and guardrails, but final merge accountability remains with human maintainers and PR authors. + +## 2) Required Repository Settings + +Maintain these branch protection rules on `main`: + +- Require status checks before merge. +- Require check `CI Required Gate`. +- Require pull request reviews before merge. +- Require CODEOWNERS review for protected paths. +- For `.github/workflows/**`, require owner approval via `CI Required Gate` (`WORKFLOW_OWNER_LOGINS`) and keep branch/ruleset bypass limited to org owners. +- Dismiss stale approvals when new commits are pushed. +- Restrict force-push on protected branches. + +## 3) PR Lifecycle + +### Step A: Intake + +- Contributor opens PR with full `.github/pull_request_template.md`. +- `PR Labeler` applies scope/path labels + size labels + risk labels + module labels (for example `channel:telegram`, `provider:kimi`, `tool:shell`) and contributor tiers by merged PR count (`trusted` >=5, `experienced` >=10, `principal` >=20, `distinguished` >=50), while de-duplicating less-specific scope labels when a more specific module label is present. +- For all module prefixes, module labels are compacted to reduce noise: one specific module keeps `prefix:component`, but multiple specifics collapse to the base scope label `prefix`. +- Label ordering is priority-first: `risk:*` -> `size:*` -> contributor tier -> module/path labels. +- Maintainers can run `PR Labeler` manually (`workflow_dispatch`) in `audit` mode for drift visibility or `repair` mode to normalize managed label metadata repository-wide. +- Hovering a label in GitHub shows its auto-managed description (rule/threshold summary). +- Managed label colors are arranged by display order to create a smooth gradient across long label rows. +- `PR Auto Responder` posts first-time guidance, handles label-driven routing for low-signal items, and auto-applies issue contributor tiers using the same thresholds as `PR Labeler` (`trusted` >=5, `experienced` >=10, `principal` >=20, `distinguished` >=50). + +### Step B: Validation + +- `CI Required Gate` is the merge gate. +- Docs-only PRs use fast-path and skip heavy Rust jobs. +- Non-doc PRs must pass lint, tests, and release build smoke check. + +### Step C: Review + +- Reviewers prioritize by risk and size labels. +- Security-sensitive paths (`src/security`, `src/runtime`, `src/gateway`, and CI workflows) require maintainer attention. +- Large PRs (`size: L`/`size: XL`) should be split unless strongly justified. + +### Step D: Merge + +- Prefer **squash merge** to keep history compact. +- PR title should follow Conventional Commit style. +- Merge only when rollback path is documented. + +## 4) PR Readiness Contracts (DoR / DoD) + +### Definition of Ready (before requesting review) + +- PR template fully completed. +- Scope boundary is explicit (what changed / what did not). +- Validation evidence attached (not just "CI will check"). +- Security and rollback fields completed for risky paths. +- Privacy/data-hygiene checks are completed and test language is neutral/project-scoped. +- If identity-like wording appears in tests/examples, it is normalized to ZeroClaw/project-native labels. + +### Definition of Done (merge-ready) + +- `CI Required Gate` is green. +- Required reviewers approved (including CODEOWNERS paths). +- Risk class labels match touched paths. +- Migration/compatibility impact is documented. +- Rollback path is concrete and fast. + +## 5) PR Size Policy + +- `size: XS` <= 80 changed lines +- `size: S` <= 250 changed lines +- `size: M` <= 500 changed lines +- `size: L` <= 1000 changed lines +- `size: XL` > 1000 changed lines + +Policy: + +- Target `XS/S/M` by default. +- `L/XL` PRs need explicit justification and tighter test evidence. +- If a large feature is unavoidable, split into stacked PRs. + +Automation behavior: + +- `PR Labeler` applies `size:*` labels from effective changed lines. +- Docs-only/lockfile-heavy PRs are normalized to avoid size inflation. + +## 6) AI/Agent Contribution Policy + +AI-assisted PRs are welcome, and review can also be agent-assisted. + +Required: + +1. Clear PR summary with scope boundary. +2. Explicit test/validation evidence. +3. Security impact and rollback notes for risky changes. + +Recommended: + +1. Brief tool/workflow notes when automation materially influenced the change. +2. Optional prompt/plan snippets for reproducibility. + +We do **not** require contributors to quantify AI-vs-human line ownership. + +Review emphasis for AI-heavy PRs: + +- Contract compatibility +- Security boundaries +- Error handling and fallback behavior +- Performance and memory regressions + +## 7) Review SLA and Queue Discipline + +- First maintainer triage target: within 48 hours. +- If PR is blocked, maintainer leaves one actionable checklist. +- `stale` automation is used to keep queue healthy; maintainers can apply `no-stale` when needed. +- `pr-hygiene` automation checks open PRs every 12 hours and posts a nudge when a PR has no new commits for 48+ hours and is either behind `main` or missing/failing `CI Required Gate` on the head commit. + +Backlog pressure controls: + +- Use a review queue budget: limit concurrent deep-review PRs per maintainer and keep the rest in triage state. +- For stacked work, require explicit `Depends on #...` so review order is deterministic. +- If a new PR replaces an older open PR, require `Supersedes #...` and close the older one after maintainer confirmation. +- Mark dormant/redundant PRs with `stale-candidate` or `superseded` to reduce duplicate review effort. + +Issue triage discipline: + +- `r:needs-repro` for incomplete bug reports (request deterministic repro before deep triage). +- `r:support` for usage/help items better handled outside bug backlog. +- `invalid` / `duplicate` labels trigger **issue-only** closing automation with guidance. + +Automation side-effect guards: + +- `PR Auto Responder` deduplicates label-based comments to avoid spam. +- Automated close routes are limited to issues, not PRs. +- Maintainers can freeze automated risk recalculation with `risk: manual` when context demands human override. + +## 8) Security and Stability Rules + +Changes in these areas require stricter review and stronger test evidence: + +- `src/security/**` +- runtime process management +- gateway ingress/authentication behavior (`src/gateway/**`) +- filesystem access boundaries +- network/authentication behavior +- GitHub workflows and release pipeline +- tools with execution capability (`src/tools/**`) + +Minimum for risky PRs: + +- threat/risk statement +- mitigation notes +- rollback steps + +Recommended for high-risk PRs: + +- include a focused test proving boundary behavior +- include one explicit failure-mode scenario and expected degradation + +For agent-assisted contributions, reviewers should also verify the author demonstrates understanding of runtime behavior and blast radius. + +## 9) Failure Recovery + +If a merged PR causes regressions: + +1. Revert PR immediately on `main`. +2. Open a follow-up issue with root-cause analysis. +3. Re-introduce fix only with regression tests. + +Prefer fast restore of service quality over delayed perfect fixes. + +## 10) Maintainer Checklist (Merge-Ready) + +- Scope is focused and understandable. +- CI gate is green. +- Docs-quality checks are green when docs changed. +- Security impact fields are complete. +- Privacy/data-hygiene fields are complete and evidence is redacted/anonymized. +- Agent workflow notes are sufficient for reproducibility (if automation was used). +- Rollback plan is explicit. +- Commit title follows Conventional Commits. + +## 11) Agent Review Operating Model + +To keep review quality stable under high PR volume, we use a two-lane review model: + +### Lane A: Fast triage (agent-friendly) + +- Confirm PR template completeness. +- Confirm CI gate signal (`CI Required Gate`). +- Confirm risk class via labels and touched paths. +- Confirm rollback statement exists. +- Confirm privacy/data-hygiene section and neutral wording requirements are satisfied. +- Confirm any required identity-like wording uses ZeroClaw/project-native terminology. + +### Lane B: Deep review (risk-based) + +Required for high-risk changes (security/runtime/gateway/CI): + +- Validate threat model assumptions. +- Validate failure mode and degradation behavior. +- Validate backward compatibility and migration impact. +- Validate observability/logging impact. + +## 12) Queue Priority and Label Discipline + +Triage order recommendation: + +1. `size: XS`/`size: S` + bug/security fixes +2. `size: M` focused changes +3. `size: L`/`size: XL` split requests or staged review + +Label discipline: + +- Path labels identify subsystem ownership quickly. +- Size labels drive batching strategy. +- Risk labels drive review depth (`risk: low/medium/high`). +- Module labels (`: `) improve reviewer routing for integration-specific changes and future newly-added modules. +- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context. +- `no-stale` is reserved for accepted-but-blocked work. + +## 13) Agent Handoff Contract + +When one agent hands off to another (or to a maintainer), include: + +1. Scope boundary (what changed / what did not). +2. Validation evidence. +3. Open risks and unknowns. +4. Suggested next action. + +This keeps context loss low and avoids repeated deep dives. diff --git a/docs/resource-limits.md b/docs/resource-limits.md new file mode 100644 index 0000000..e3834fc --- /dev/null +++ b/docs/resource-limits.md @@ -0,0 +1,100 @@ +# Resource Limits for ZeroClaw + +## Problem +ZeroClaw has rate limiting (20 actions/hour) but no resource caps. A runaway agent could: +- Exhaust available memory +- Spin CPU at 100% +- Fill disk with logs/output + +--- + +## Proposed Solutions + +### Option 1: cgroups v2 (Linux, Recommended) +Automatically create a cgroup for zeroclaw with limits. + +```bash +# Create systemd service with limits +[Service] +MemoryMax=512M +CPUQuota=100% +IOReadBandwidthMax=/dev/sda 10M +IOWriteBandwidthMax=/dev/sda 10M +TasksMax=100 +``` + +### Option 2: tokio::task::deadlock detection +Prevent task starvation. + +```rust +use tokio::time::{timeout, Duration}; + +pub async fn execute_with_timeout( + fut: F, + cpu_time_limit: Duration, + memory_limit: usize, +) -> Result +where + F: Future>, +{ + // CPU timeout + timeout(cpu_time_limit, fut).await? +} +``` + +### Option 3: Memory monitoring +Track heap usage and kill if over limit. + +```rust +use std::alloc::{GlobalAlloc, Layout, System}; + +struct LimitedAllocator { + inner: A, + max_bytes: usize, + used: std::sync::atomic::AtomicUsize, +} + +unsafe impl GlobalAlloc for LimitedAllocator { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let current = self.used.fetch_add(layout.size(), std::sync::atomic::Ordering::Relaxed); + if current + layout.size() > self.max_bytes { + std::process::abort(); + } + self.inner.alloc(layout) + } +} +``` + +--- + +## Config Schema + +```toml +[resources] +# Memory limits (in MB) +max_memory_mb = 512 +max_memory_per_command_mb = 128 + +# CPU limits +max_cpu_percent = 50 +max_cpu_time_seconds = 60 + +# Disk I/O limits +max_log_size_mb = 100 +max_temp_storage_mb = 500 + +# Process limits +max_subprocesses = 10 +max_open_files = 100 +``` + +--- + +## Implementation Priority + +| Phase | Feature | Effort | Impact | +|-------|---------|--------|--------| +| **P0** | Memory monitoring + kill | Low | High | +| **P1** | CPU timeout per command | Low | High | +| **P2** | cgroups integration (Linux) | Medium | Very High | +| **P3** | Disk I/O limits | Medium | Medium | diff --git a/docs/reviewer-playbook.md b/docs/reviewer-playbook.md new file mode 100644 index 0000000..6f72fea --- /dev/null +++ b/docs/reviewer-playbook.md @@ -0,0 +1,110 @@ +# Reviewer Playbook + +This playbook is the operational companion to [`docs/pr-workflow.md`](pr-workflow.md). +Use it to reduce review latency without reducing quality. + +## 1) Review Objectives + +- Keep queue throughput predictable. +- Keep risk review proportionate to change risk. +- Keep merge decisions reproducible and auditable. + +## 2) 5-Minute Intake Triage + +For every new PR, do a fast intake pass: + +1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`). +2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible. +3. Confirm CI signal status (`CI Required Gate`). +4. Confirm scope is one concern (reject mixed mega-PRs unless justified). +5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied. + +If any intake requirement fails, leave one actionable checklist comment instead of deep review. + +## 3) Risk-to-Depth Matrix + +| Risk label | Typical touched paths | Minimum review depth | +|---|---|---| +| `risk: low` | docs/tests/chore, isolated non-runtime changes | 1 reviewer + CI gate | +| `risk: medium` | `src/providers/**`, `src/channels/**`, `src/memory/**`, `src/config/**` | 1 subsystem-aware reviewer + behavior verification | +| `risk: high` | `src/security/**`, `src/runtime/**`, `src/gateway/**`, `src/tools/**`, `.github/workflows/**` | fast triage + deep review, strong rollback and failure-mode checks | + +When uncertain, treat as `risk: high`. + +If automated risk labeling is contextually wrong, maintainers can apply `risk: manual` and set the final risk label explicitly. + +## 4) Fast-Lane Checklist (All PRs) + +- Scope boundary is explicit and believable. +- Validation commands are present and results are coherent. +- User-facing behavior changes are documented. +- Author demonstrates understanding of behavior and blast radius (especially for agent-assisted PRs). +- Rollback path is concrete (not just “revert”). +- Compatibility/migration impacts are clear. +- No personal/sensitive data leakage in diff artifacts; examples/tests remain neutral and project-scoped. +- If identity-like wording exists, it uses ZeroClaw/project-native roles (not personal or real-world identities). +- Naming and architecture boundaries follow project contracts (`AGENTS.md`, `CONTRIBUTING.md`). + +## 5) Deep Review Checklist (High Risk) + +For high-risk PRs, verify at least one example in each category: + +- **Security boundaries**: deny-by-default behavior preserved, no accidental scope broadening. +- **Failure modes**: error handling is explicit and degrades safely. +- **Contract stability**: CLI/config/API compatibility preserved or migration documented. +- **Observability**: failures are diagnosable without leaking secrets. +- **Rollback safety**: revert path and blast radius are clear. + +## 6) Issue Triage Playbook + +Use labels to keep backlog actionable: + +- `r:needs-repro` for incomplete bug reports. +- `r:support` for usage/support questions better routed outside bug backlog. +- `duplicate` / `invalid` for non-actionable duplicates/noise. +- `no-stale` for accepted work waiting on external blockers. +- Request redaction if logs/payloads include personal identifiers or sensitive data. + +## 7) Review Comment Style + +Prefer checklist-style comments with one of these outcomes: + +- **Ready to merge** (explicitly say why). +- **Needs author action** (ordered list of blockers). +- **Needs deeper security/runtime review** (state exact risk and requested evidence). + +Avoid vague comments that create back-and-forth latency. + +## 8) Automation Override Protocol + +Use this when automation output creates review side effects: + +1. **Incorrect risk label**: add `risk: manual`, then set the intended `risk:*` label. +2. **Incorrect auto-close on issue triage**: reopen issue, remove route label, and leave one clarifying comment. +3. **Label spam/noise**: keep one canonical maintainer comment and remove redundant route labels. +4. **Ambiguous PR scope**: request split before deep review. + +### PR Backlog Pruning Protocol + +When review demand exceeds capacity, apply this order: + +1. Keep active bug/security PRs (`size: XS/S`) at the top of queue. +2. Ask overlapping PRs to consolidate; close older ones as `superseded` after acknowledgement. +3. Mark dormant PRs as `stale-candidate` before stale closure window starts. +4. Require rebase + fresh validation before reopening stale/superseded technical work. + +## 9) Handoff Protocol + +If handing off review to another maintainer/agent, include: + +1. Scope summary +2. Current risk class and why +3. What has been validated already +4. Open blockers +5. Suggested next action + +## 10) Weekly Queue Hygiene + +- Review stale queue and apply `no-stale` only to accepted-but-blocked work. +- Prioritize `size: XS/S` bug/security PRs first. +- Convert recurring support issues into docs updates and auto-response guidance. diff --git a/docs/sandboxing.md b/docs/sandboxing.md new file mode 100644 index 0000000..06abf59 --- /dev/null +++ b/docs/sandboxing.md @@ -0,0 +1,190 @@ +# ZeroClaw Sandboxing Strategies + +## Problem +ZeroClaw currently has application-layer security (allowlists, path blocking, command injection protection) but lacks OS-level containment. If an attacker is on the allowlist, they can run any allowed command with zeroclaw's user permissions. + +## Proposed Solutions + +### Option 1: Firejail Integration (Recommended for Linux) +Firejail provides user-space sandboxing with minimal overhead. + +```rust +// src/security/firejail.rs +use std::process::Command; + +pub struct FirejailSandbox { + enabled: bool, +} + +impl FirejailSandbox { + pub fn new() -> Self { + let enabled = which::which("firejail").is_ok(); + Self { enabled } + } + + pub fn wrap_command(&self, cmd: &mut Command) -> &mut Command { + if !self.enabled { + return cmd; + } + + // Firejail wraps any command with sandboxing + let mut jail = Command::new("firejail"); + jail.args([ + "--private=home", // New home directory + "--private-dev", // Minimal /dev + "--nosound", // No audio + "--no3d", // No 3D acceleration + "--novideo", // No video devices + "--nowheel", // No input devices + "--notv", // No TV devices + "--noprofile", // Skip profile loading + "--quiet", // Suppress warnings + ]); + + // Append original command + if let Some(program) = cmd.get_program().to_str() { + jail.arg(program); + } + for arg in cmd.get_args() { + if let Some(s) = arg.to_str() { + jail.arg(s); + } + } + + // Replace original command with firejail wrapper + *cmd = jail; + cmd + } +} +``` + +**Config option:** +```toml +[security] +enable_sandbox = true +sandbox_backend = "firejail" # or "none", "bubblewrap", "docker" +``` + +--- + +### Option 2: Bubblewrap (Portable, no root required) +Bubblewrap uses user namespaces to create containers. + +```bash +# Install bubblewrap +sudo apt install bubblewrap + +# Wrap command: +bwrap --ro-bind /usr /usr \ + --dev /dev \ + --proc /proc \ + --bind /workspace /workspace \ + --unshare-all \ + --share-net \ + --die-with-parent \ + -- /bin/sh -c "command" +``` + +--- + +### Option 3: Docker-in-Docker (Heavyweight but complete isolation) +Run agent tools inside ephemeral containers. + +```rust +pub struct DockerSandbox { + image: String, +} + +impl DockerSandbox { + pub async fn execute(&self, command: &str, workspace: &Path) -> Result { + let output = Command::new("docker") + .args([ + "run", "--rm", + "--memory", "512m", + "--cpus", "1.0", + "--network", "none", + "--volume", &format!("{}:/workspace", workspace.display()), + &self.image, + "sh", "-c", command + ]) + .output() + .await?; + + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } +} +``` + +--- + +### Option 4: Landlock (Linux Kernel LSM, Rust native) +Landlock provides file system access control without containers. + +```rust +use landlock::{Ruleset, AccessFS}; + +pub fn apply_landlock() -> Result<()> { + let ruleset = Ruleset::new() + .set_access_fs(AccessFS::read_file | AccessFS::write_file) + .add_path(Path::new("/workspace"), AccessFS::read_file | AccessFS::write_file)? + .add_path(Path::new("/tmp"), AccessFS::read_file | AccessFS::write_file)? + .restrict_self()?; + + Ok(()) +} +``` + +--- + +## Priority Implementation Order + +| Phase | Solution | Effort | Security Gain | +|-------|----------|--------|---------------| +| **P0** | Landlock (Linux only, native) | Low | High (filesystem) | +| **P1** | Firejail integration | Low | Very High | +| **P2** | Bubblewrap wrapper | Medium | Very High | +| **P3** | Docker sandbox mode | High | Complete | + +## Config Schema Extension + +```toml +[security.sandbox] +enabled = true +backend = "auto" # auto | firejail | bubblewrap | landlock | docker | none + +# Firejail-specific +[security.sandbox.firejail] +extra_args = ["--seccomp", "--caps.drop=all"] + +# Landlock-specific +[security.sandbox.landlock] +readonly_paths = ["/usr", "/bin", "/lib"] +readwrite_paths = ["$HOME/workspace", "/tmp/zeroclaw"] +``` + +## Testing Strategy + +```rust +#[cfg(test)] +mod tests { + #[test] + fn sandbox_blocks_path_traversal() { + // Try to read /etc/passwd through sandbox + let result = sandboxed_execute("cat /etc/passwd"); + assert!(result.is_err()); + } + + #[test] + fn sandbox_allows_workspace_access() { + let result = sandboxed_execute("ls /workspace"); + assert!(result.is_ok()); + } + + #[test] + fn sandbox_no_network_isolation() { + // Ensure network is blocked when configured + let result = sandboxed_execute("curl http://example.com"); + assert!(result.is_err()); + } +} +``` diff --git a/docs/security-roadmap.md b/docs/security-roadmap.md new file mode 100644 index 0000000..6578d1f --- /dev/null +++ b/docs/security-roadmap.md @@ -0,0 +1,180 @@ +# ZeroClaw Security Improvement Roadmap + +## Current State: Strong Foundation + +ZeroClaw already has **excellent application-layer security**: + +✅ Command allowlist (not blocklist) +✅ Path traversal protection +✅ Command injection blocking (`$(...)`, backticks, `&&`, `>`) +✅ Secret isolation (API keys not leaked to shell) +✅ Rate limiting (20 actions/hour) +✅ Channel authorization (empty = deny all, `*` = allow all) +✅ Risk classification (Low/Medium/High) +✅ Environment variable sanitization +✅ Forbidden paths blocking +✅ Comprehensive test coverage (1,017 tests) + +## What's Missing: OS-Level Containment + +🔴 No OS-level sandboxing (chroot, containers, namespaces) +🔴 No resource limits (CPU, memory, disk I/O caps) +🔴 No tamper-evident audit logging +🔴 No syscall filtering (seccomp) + +--- + +## Comparison: ZeroClaw vs PicoClaw vs Production Grade + +| Feature | PicoClaw | ZeroClaw Now | ZeroClaw + Roadmap | Production Target | +|---------|----------|--------------|-------------------|-------------------| +| **Binary Size** | ~8MB | **3.4MB** ✅ | 3.5-4MB | < 5MB | +| **RAM Usage** | < 10MB | **< 5MB** ✅ | < 10MB | < 20MB | +| **Startup Time** | < 1s | **< 10ms** ✅ | < 50ms | < 100ms | +| **Command Allowlist** | Unknown | ✅ Yes | ✅ Yes | ✅ Yes | +| **Path Blocking** | Unknown | ✅ Yes | ✅ Yes | ✅ Yes | +| **Injection Protection** | Unknown | ✅ Yes | ✅ Yes | ✅ Yes | +| **OS Sandbox** | No | ❌ No | ✅ Firejail/Landlock | ✅ Container/namespaces | +| **Resource Limits** | No | ❌ No | ✅ cgroups/Monitor | ✅ Full cgroups | +| **Audit Logging** | No | ❌ No | ✅ HMAC-signed | ✅ SIEM integration | +| **Security Score** | C | **B+** | **A-** | **A+** | + +--- + +## Implementation Roadmap + +### Phase 1: Quick Wins (1-2 weeks) +**Goal**: Address critical gaps with minimal complexity + +| Task | File | Effort | Impact | +|------|------|--------|-------| +| Landlock filesystem sandbox | `src/security/landlock.rs` | 2 days | High | +| Memory monitoring + OOM kill | `src/resources/memory.rs` | 1 day | High | +| CPU timeout per command | `src/tools/shell.rs` | 1 day | High | +| Basic audit logging | `src/security/audit.rs` | 2 days | Medium | +| Config schema updates | `src/config/schema.rs` | 1 day | - | + +**Deliverables**: +- Linux: Filesystem access restricted to workspace +- All platforms: Memory/CPU guards against runaway commands +- All platforms: Tamper-evident audit trail + +--- + +### Phase 2: Platform Integration (2-3 weeks) +**Goal**: Deep OS integration for production-grade isolation + +| Task | Effort | Impact | +|------|--------|-------| +| Firejail auto-detection + wrapping | 3 days | Very High | +| Bubblewrap wrapper for macOS/*nix | 4 days | Very High | +| cgroups v2 systemd integration | 3 days | High | +| seccomp syscall filtering | 5 days | High | +| Audit log query CLI | 2 days | Medium | + +**Deliverables**: +- Linux: Full container-like isolation via Firejail +- macOS: Bubblewrap filesystem isolation +- Linux: cgroups resource enforcement +- Linux: Syscall allowlisting + +--- + +### Phase 3: Production Hardening (1-2 weeks) +**Goal**: Enterprise security features + +| Task | Effort | Impact | +|------|--------|-------| +| Docker sandbox mode option | 3 days | High | +| Certificate pinning for channels | 2 days | Medium | +| Signed config verification | 2 days | Medium | +| SIEM-compatible audit export | 2 days | Medium | +| Security self-test (`zeroclaw audit --check`) | 1 day | Low | + +**Deliverables**: +- Optional Docker-based execution isolation +- HTTPS certificate pinning for channel webhooks +- Config file signature verification +- JSON/CSV audit export for external analysis + +--- + +## New Config Schema Preview + +```toml +[security] +level = "strict" # relaxed | default | strict | paranoid + +# Sandbox configuration +[security.sandbox] +enabled = true +backend = "auto" # auto | firejail | bubblewrap | landlock | docker | none + +# Resource limits +[resources] +max_memory_mb = 512 +max_memory_per_command_mb = 128 +max_cpu_percent = 50 +max_cpu_time_seconds = 60 +max_subprocesses = 10 + +# Audit logging +[security.audit] +enabled = true +log_path = "~/.config/zeroclaw/audit.log" +sign_events = true +max_size_mb = 100 + +# Autonomy (existing, enhanced) +[autonomy] +level = "supervised" # readonly | supervised | full +allowed_commands = ["git", "ls", "cat", "grep", "find"] +forbidden_paths = ["/etc", "/root", "~/.ssh"] +require_approval_for_medium_risk = true +block_high_risk_commands = true +max_actions_per_hour = 20 +``` + +--- + +## CLI Commands Preview + +```bash +# Security status check +zeroclaw security --check +# → ✓ Sandbox: Firejail active +# → ✓ Audit logging enabled (42 events today) +# → → Resource limits: 512MB mem, 50% CPU + +# Audit log queries +zeroclaw audit --user @alice --since 24h +zeroclaw audit --risk high --violations-only +zeroclaw audit --verify-signatures + +# Sandbox test +zeroclaw sandbox --test +# → Testing isolation... +# ✓ Cannot read /etc/passwd +# ✓ Cannot access ~/.ssh +# ✓ Can read /workspace +``` + +--- + +## Summary + +**ZeroClaw is already more secure than PicoClaw** with: +- 50% smaller binary (3.4MB vs 8MB) +- 50% less RAM (< 5MB vs < 10MB) +- 100x faster startup (< 10ms vs < 1s) +- Comprehensive security policy engine +- Extensive test coverage + +**By implementing this roadmap**, ZeroClaw becomes: +- Production-grade with OS-level sandboxing +- Resource-aware with memory/CPU guards +- Audit-ready with tamper-evident logging +- Enterprise-ready with configurable security levels + +**Estimated effort**: 4-7 weeks for full implementation +**Value**: Transforms ZeroClaw from "safe for testing" to "safe for production" diff --git a/examples/custom_channel.rs b/examples/custom_channel.rs index dd3fdf8..790762d 100644 --- a/examples/custom_channel.rs +++ b/examples/custom_channel.rs @@ -12,6 +12,8 @@ use tokio::sync::mpsc; pub struct ChannelMessage { pub id: String, pub sender: String, + /// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id). + pub reply_to: String, pub content: String, pub channel: String, pub timestamp: u64, @@ -90,9 +92,12 @@ impl Channel for TelegramChannel { continue; } + let chat_id = msg["chat"]["id"].to_string(); + let channel_msg = ChannelMessage { id: msg["message_id"].to_string(), sender, + reply_to: chat_id, content: msg["text"].as_str().unwrap_or("").to_string(), channel: "telegram".into(), timestamp: msg["date"].as_u64().unwrap_or(0), diff --git a/firmware/zeroclaw-arduino/zeroclaw-arduino.ino b/firmware/zeroclaw-arduino/zeroclaw-arduino.ino new file mode 100644 index 0000000..5e9c4ee --- /dev/null +++ b/firmware/zeroclaw-arduino/zeroclaw-arduino.ino @@ -0,0 +1,143 @@ +/* + * ZeroClaw Arduino Uno Firmware + * + * Listens for JSON commands on Serial (115200 baud), executes gpio_read/gpio_write, + * responds with JSON. Compatible with ZeroClaw SerialPeripheral protocol. + * + * Protocol (newline-delimited JSON): + * Request: {"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}} + * Response: {"id":"1","ok":true,"result":"done"} + * + * Arduino Uno: Pin 13 has built-in LED. Digital pins 0-13 supported. + * + * 1. Open in Arduino IDE + * 2. Select Board: Arduino Uno + * 3. Select correct Port (Tools -> Port) + * 4. Upload + */ + +#define BAUDRATE 115200 +#define MAX_LINE 256 + +char lineBuf[MAX_LINE]; +int lineLen = 0; + +// Parse integer from JSON: "pin":13 or "value":1 +int parseArg(const char* key, const char* json) { + char search[32]; + snprintf(search, sizeof(search), "\"%s\":", key); + const char* p = strstr(json, search); + if (!p) return -1; + p += strlen(search); + return atoi(p); +} + +// Extract "id" for response +void copyId(char* out, int outLen, const char* json) { + const char* p = strstr(json, "\"id\":\""); + if (!p) { + out[0] = '0'; + out[1] = '\0'; + return; + } + p += 6; + int i = 0; + while (i < outLen - 1 && *p && *p != '"') { + out[i++] = *p++; + } + out[i] = '\0'; +} + +// Check if cmd is present +bool hasCmd(const char* json, const char* cmd) { + char search[64]; + snprintf(search, sizeof(search), "\"cmd\":\"%s\"", cmd); + return strstr(json, search) != NULL; +} + +void handleLine(const char* line) { + char idBuf[16]; + copyId(idBuf, sizeof(idBuf), line); + + if (hasCmd(line, "ping")) { + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.println("\",\"ok\":true,\"result\":\"pong\"}"); + return; + } + + // Phase C: Dynamic discovery — report GPIO pins and LED pin + if (hasCmd(line, "capabilities")) { + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.print("\",\"ok\":true,\"result\":\"{\\\"gpio\\\":[0,1,2,3,4,5,6,7,8,9,10,11,12,13],\\\"led_pin\\\":13}\"}"); + Serial.println(); + return; + } + + if (hasCmd(line, "gpio_read")) { + int pin = parseArg("pin", line); + if (pin < 0 || pin > 13) { + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.print("\",\"ok\":false,\"result\":\"\",\"error\":\"Invalid pin "); + Serial.print(pin); + Serial.println("\"}"); + return; + } + pinMode(pin, INPUT); + int val = digitalRead(pin); + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.print("\",\"ok\":true,\"result\":\""); + Serial.print(val); + Serial.println("\"}"); + return; + } + + if (hasCmd(line, "gpio_write")) { + int pin = parseArg("pin", line); + int value = parseArg("value", line); + if (pin < 0 || pin > 13) { + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.print("\",\"ok\":false,\"result\":\"\",\"error\":\"Invalid pin "); + Serial.print(pin); + Serial.println("\"}"); + return; + } + pinMode(pin, OUTPUT); + digitalWrite(pin, value ? HIGH : LOW); + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.println("\",\"ok\":true,\"result\":\"done\"}"); + return; + } + + // Unknown command + Serial.print("{\"id\":\""); + Serial.print(idBuf); + Serial.println("\",\"ok\":false,\"result\":\"\",\"error\":\"Unknown command\"}"); +} + +void setup() { + Serial.begin(BAUDRATE); + lineLen = 0; +} + +void loop() { + while (Serial.available()) { + char c = Serial.read(); + if (c == '\n' || c == '\r') { + if (lineLen > 0) { + lineBuf[lineLen] = '\0'; + handleLine(lineBuf); + lineLen = 0; + } + } else if (lineLen < MAX_LINE - 1) { + lineBuf[lineLen++] = c; + } else { + lineLen = 0; // Overflow, discard + } + } +} diff --git a/firmware/zeroclaw-esp32-ui/.cargo/config.toml b/firmware/zeroclaw-esp32-ui/.cargo/config.toml new file mode 100644 index 0000000..83dced8 --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/.cargo/config.toml @@ -0,0 +1,13 @@ +[build] +target = "riscv32imc-esp-espidf" + +[target.riscv32imc-esp-espidf] +linker = "ldproxy" +rustflags = [ + "--cfg", 'espidf_time64', + "-C", "default-linker-libraries", +] + +[unstable] +build-std = ["std", "panic_abort"] +build-std-features = ["panic_immediate_abort"] diff --git a/firmware/zeroclaw-esp32-ui/Cargo.toml b/firmware/zeroclaw-esp32-ui/Cargo.toml new file mode 100644 index 0000000..5c7ddcc --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/Cargo.toml @@ -0,0 +1,46 @@ +[package] +name = "zeroclaw-esp32-ui" +version = "0.1.0" +edition = "2021" +license = "MIT" +description = "ZeroClaw ESP32 UI firmware with Slint - Graphical interface for AI assistant" +authors = ["ZeroClaw Team"] + +[dependencies] +anyhow = "1.0" +esp-idf-svc = "0.48" +log = { version = "0.4", default-features = false } + +# Slint UI - MCU optimized +slint = { version = "1.10", default-features = false, features = [ + "compat-1-2", + "libm", + "renderer-software", +] } + +[build-dependencies] +embuild = { version = "0.31", features = ["elf"] } +slint-build = "1.10" + +[features] +default = ["std", "display-st7789"] +std = ["esp-idf-svc/std"] + +# Display selection (choose one) +display-st7789 = [] # 320x240 or 135x240 +display-ili9341 = [] # 320x240 +display-ssd1306 = [] # 128x64 OLED + +# Input +touch-xpt2046 = [] # Resistive touch +touch-ft6x36 = [] # Capacitive touch + +[profile.release] +opt-level = "s" +lto = true +codegen-units = 1 +strip = true +panic = "abort" + +[profile.dev] +opt-level = "s" diff --git a/firmware/zeroclaw-esp32-ui/README.md b/firmware/zeroclaw-esp32-ui/README.md new file mode 100644 index 0000000..ffba119 --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/README.md @@ -0,0 +1,106 @@ +# ZeroClaw ESP32 UI Firmware + +Slint-based graphical UI firmware scaffold for ZeroClaw edge scenarios on ESP32. + +## Scope of This Crate + +This crate intentionally provides a **minimal, bootable UI scaffold**: + +- Initializes ESP-IDF logging/runtime patches +- Compiles and runs a small Slint UI (`MainWindow`) +- Keeps display and touch feature flags available for incremental driver integration + +What this crate **does not** do yet: + +- No full chat runtime integration +- No production display/touch driver wiring in `src/main.rs` +- No Wi-Fi/BLE transport logic + +## Features + +- **Slint UI scaffold** suitable for MCU-oriented iteration +- **Display feature flags** for ST7789, ILI9341, SSD1306 +- **Touch feature flags** for XPT2046 and FT6X36 integration planning +- **ESP-IDF baseline** for embedded target builds + +## Project Structure + +```text +firmware/zeroclaw-esp32-ui/ +├── Cargo.toml # Rust package and feature flags +├── build.rs # Slint compilation hook +├── .cargo/ +│ └── config.toml # Cross-compilation defaults +├── ui/ +│ └── main.slint # Slint UI definition +└── src/ + └── main.rs # Firmware entry point +``` + +## Prerequisites + +1. **ESP Rust toolchain** + ```bash + cargo install espup + espup install + source ~/export-esp.sh + ``` + +2. **Flashing tools** + ```bash + cargo install espflash cargo-espflash + ``` + +## Build and Flash + +### Default target (ESP32-C3, from `.cargo/config.toml`) + +```bash +cd firmware/zeroclaw-esp32-ui +cargo build --release +cargo espflash flash --release --monitor +``` + +### Build for ESP32-S3 (override target) + +```bash +cargo build --release --target xtensa-esp32s3-espidf +``` + +## Feature Flags + +```bash +# Switch display profile +cargo build --release --features display-ili9341 + +# Enable planned touch profile +cargo build --release --features touch-ft6x36 +``` + +## UI Layout + +The current `ui/main.slint` defines: + +- `StatusBar` +- `MessageList` +- `InputBar` +- `MainWindow` + +These components are placeholders to keep future hardware integration incremental and low-risk. + +## Next Integration Steps + +1. Wire real display driver initialization in `src/main.rs` +2. Attach touch input events to Slint callbacks +3. Connect UI state with ZeroClaw edge/runtime messaging +4. Add board-specific pin maps with explicit target profiles + +## License + +MIT - See root `LICENSE` + +## References + +- [Slint ESP32 Documentation](https://slint.dev/esp32) +- [ESP-IDF Rust Book](https://esp-rs.github.io/book/) +- [ZeroClaw Hardware Design](../../docs/hardware-peripherals-design.md) diff --git a/firmware/zeroclaw-esp32-ui/build.rs b/firmware/zeroclaw-esp32-ui/build.rs new file mode 100644 index 0000000..0d99898 --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/build.rs @@ -0,0 +1,14 @@ +use embuild::espidf::sysenv::output; + +fn main() { + output(); + slint_build::compile_with_config( + "ui/main.slint", + slint_build::CompilerConfiguration::new() + .embed_resources(slint_build::EmbedResourcesKind::EmbedForSoftwareRenderer) + .with_style("material".into()), + ) + .expect("Slint UI compilation failed"); + + println!("cargo:rerun-if-changed=ui/"); +} diff --git a/firmware/zeroclaw-esp32-ui/src/main.rs b/firmware/zeroclaw-esp32-ui/src/main.rs new file mode 100644 index 0000000..6db084e --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/src/main.rs @@ -0,0 +1,22 @@ +//! ZeroClaw ESP32 UI firmware scaffold. +//! +//! This binary initializes ESP-IDF, boots a minimal Slint UI, and keeps +//! architecture boundaries explicit so hardware integrations can be added +//! incrementally. + +use anyhow::Context; +use log::info; + +slint::include_modules!(); + +fn main() -> anyhow::Result<()> { + esp_idf_svc::sys::link_patches(); + esp_idf_svc::log::EspLogger::initialize_default(); + + info!("Starting ZeroClaw ESP32 UI scaffold"); + + let window = MainWindow::new().context("failed to create MainWindow")?; + window.run().context("MainWindow event loop failed")?; + + Ok(()) +} diff --git a/firmware/zeroclaw-esp32-ui/ui/main.slint b/firmware/zeroclaw-esp32-ui/ui/main.slint new file mode 100644 index 0000000..f2815b3 --- /dev/null +++ b/firmware/zeroclaw-esp32-ui/ui/main.slint @@ -0,0 +1,83 @@ +component StatusBar inherits Rectangle { + in property title_text: "ZeroClaw ESP32 UI"; + in property status_text: "disconnected"; + + height: 32px; + background: #1f2937; + border-radius: 6px; + + HorizontalLayout { + padding: 8px; + + Text { + text: root.title_text; + color: #e5e7eb; + font-size: 14px; + vertical-alignment: center; + } + + Text { + text: root.status_text; + color: #93c5fd; + font-size: 12px; + horizontal-alignment: right; + vertical-alignment: center; + } + } +} + +component MessageList inherits Rectangle { + in property message_text: "UI scaffold is running"; + + background: #0f172a; + border-radius: 6px; + border-color: #334155; + border-width: 1px; + + Text { + text: root.message_text; + color: #cbd5e1; + horizontal-alignment: center; + vertical-alignment: center; + } +} + +component InputBar inherits Rectangle { + in property hint_text: "Touch input integration pending"; + + height: 36px; + background: #1e293b; + border-radius: 6px; + + Text { + text: root.hint_text; + color: #e2e8f0; + horizontal-alignment: center; + vertical-alignment: center; + font-size: 12px; + } +} + +export component MainWindow inherits Window { + width: 320px; + height: 240px; + background: #020617; + + VerticalLayout { + padding: 10px; + spacing: 10px; + + StatusBar { + title_text: "ZeroClaw Edge UI"; + status_text: "booting"; + } + + MessageList { + message_text: "Display/touch drivers can be wired here"; + } + + InputBar { + hint_text: "Use touch-xpt2046 or touch-ft6x36 feature later"; + } + } +} diff --git a/firmware/zeroclaw-esp32/.cargo/config.toml b/firmware/zeroclaw-esp32/.cargo/config.toml new file mode 100644 index 0000000..56dd71b --- /dev/null +++ b/firmware/zeroclaw-esp32/.cargo/config.toml @@ -0,0 +1,11 @@ +[build] +target = "riscv32imc-esp-espidf" + +[target.riscv32imc-esp-espidf] +linker = "ldproxy" +runner = "espflash flash --monitor" +# ESP-IDF 5.x uses 64-bit time_t +rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"] + +[unstable] +build-std = ["std", "panic_abort"] diff --git a/firmware/zeroclaw-esp32/Cargo.lock b/firmware/zeroclaw-esp32/Cargo.lock new file mode 100644 index 0000000..69e989b --- /dev/null +++ b/firmware/zeroclaw-esp32/Cargo.lock @@ -0,0 +1,1794 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "aligned" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4508988c62edf04abd8d92897fca0c2995d907ce1dfeaf369dac3716a40685" +dependencies = [ + "as-slice", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "anyhow" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" + +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags 2.11.0", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.116", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bstr" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "build-time" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1219c19fc29b7bfd74b7968b420aff5bc951cf517800176e795d6b2300dd382" +dependencies = [ + "chrono", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "camino" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629a66d692cb9ff1a1c664e41771b3dcaf961985a9774c0eb0bd1b51cf60a48" +dependencies = [ + "serde_core", +] + +[[package]] +name = "cargo-platform" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", + "thiserror 1.0.69", +] + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "chrono" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +dependencies = [ + "iana-time-zone", + "num-traits", + "windows-link", +] + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "const_format" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7faa7469a93a566e9ccc1c73fe783b4a65c274c5ace346038dca9c39fe0030ad" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "cvt" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ae9bf77fbf2d39ef573205d554d87e86c12f1994e9ea335b0651b9b278bcf1" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "defmt" +version = "0.3.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0963443817029b2024136fc4dd07a5107eb8f977eaf18fcd1fdeb11306b64ad" +dependencies = [ + "defmt 1.0.1", +] + +[[package]] +name = "defmt" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "548d977b6da32fa1d1fda2876453da1e7df63ad0304c8b3dae4dbe7b96f39b78" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d4fc12a85bcf441cfe44344c4b72d58493178ce635338a3f3b78943aceb258e" +dependencies = [ + "defmt-parser", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "defmt-parser" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10d60334b3b2e7c9d91ef8150abfb6fa4c1c39ebbcf4a81c2e346aad939fee3e" +dependencies = [ + "thiserror 2.0.18", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "embassy-futures" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01" + +[[package]] +name = "embassy-sync" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b" +dependencies = [ + "cfg-if", + "critical-section", + "embedded-io-async", + "futures-core", + "futures-sink", + "heapless", +] + +[[package]] +name = "embedded-can" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9d2e857f87ac832df68fa498d18ddc679175cf3d2e4aa893988e5601baf9438" +dependencies = [ + "nb 1.1.0", +] + +[[package]] +name = "embedded-hal" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35949884794ad573cf46071e41c9b60efb0cb311e3ca01f7af807af1debc66ff" +dependencies = [ + "nb 0.1.3", + "void", +] + +[[package]] +name = "embedded-hal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" + +[[package]] +name = "embedded-hal-async" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4c685bbef7fe13c3c6dd4da26841ed3980ef33e841cddfa15ce8a8fb3f1884" +dependencies = [ + "embedded-hal 1.0.0", +] + +[[package]] +name = "embedded-hal-nb" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fba4268c14288c828995299e59b12babdbe170f6c6d73731af1b4648142e8605" +dependencies = [ + "embedded-hal 1.0.0", + "nb 1.1.0", +] + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + +[[package]] +name = "embedded-io-async" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff09972d4073aa8c299395be75161d582e7629cd663171d62af73c8d50dba3f" +dependencies = [ + "embedded-io", +] + +[[package]] +name = "embedded-svc" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0" +dependencies = [ + "defmt 0.3.100", + "embedded-io", + "embedded-io-async", + "enumset", + "heapless", + "num_enum", + "serde", + "strum 0.25.0", +] + +[[package]] +name = "embuild" +version = "0.33.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75" +dependencies = [ + "anyhow", + "bindgen", + "bitflags 1.3.2", + "cmake", + "filetime", + "globwalk", + "home", + "log", + "regex", + "remove_dir_all", + "serde", + "serde_json", + "shlex", + "strum 0.24.1", + "tempfile", + "thiserror 1.0.69", + "which", +] + +[[package]] +name = "enumset" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25b07a8dfbbbfc0064c0a6bdf9edcf966de6b1c33ce344bdeca3b41615452634" +dependencies = [ + "enumset_derive", +] + +[[package]] +name = "enumset_derive" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43e744e4ea338060faee68ed933e46e722fb7f3617e722a5772d7e856d8b3ce" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "envy" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f47e0157f2cb54f5ae1bd371b30a2ae4311e1c028f575cd4e81de7353215965" +dependencies = [ + "serde", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "esp-idf-hal" +version = "0.45.2" +source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30" +dependencies = [ + "atomic-waker", + "embassy-sync", + "embedded-can", + "embedded-hal 0.2.7", + "embedded-hal 1.0.0", + "embedded-hal-async", + "embedded-hal-nb", + "embedded-io", + "embedded-io-async", + "embuild", + "enumset", + "esp-idf-sys", + "heapless", + "log", + "nb 1.1.0", +] + +[[package]] +name = "esp-idf-svc" +version = "0.51.0" +source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203" +dependencies = [ + "embassy-futures", + "embedded-hal-async", + "embedded-svc", + "embuild", + "enumset", + "esp-idf-hal", + "futures-io", + "heapless", + "log", + "num_enum", + "uncased", +] + +[[package]] +name = "esp-idf-sys" +version = "0.36.1" +source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849" +dependencies = [ + "anyhow", + "build-time", + "cargo_metadata", + "cmake", + "const_format", + "embuild", + "envy", + "libc", + "regex", + "serde", + "strum 0.24.1", + "which", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +dependencies = [ + "cfg-if", + "libc", + "libredox", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "fs_at" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14af6c9694ea25db25baa2a1788703b9e7c6648dcaeeebeb98f7561b5384c036" +dependencies = [ + "aligned", + "cfg-if", + "cvt", + "libc", + "nix", + "windows-sys 0.52.0", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "globset" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52dfc19153a48bde0cbd630453615c8151bce3a5adfac7a0aebfbf0a1e1f57e3" +dependencies = [ + "aho-corasick", + "bstr", + "log", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "globwalk" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93e3af942408868f6934a7b85134a3230832b9977cf66125df2f9edcfce4ddcc" +dependencies = [ + "bitflags 1.3.2", + "ignore", + "walkdir", +] + +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "ignore" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3d782a365a015e0f5c04902246139249abf769125006fbe7649e2ee88169b4a" +dependencies = [ + "crossbeam-deque", + "globset", + "log", + "memchr", + "regex-automata", + "same-file", + "walkdir", + "winapi-util", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + +[[package]] +name = "libredox" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +dependencies = [ + "bitflags 2.11.0", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "nb" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f" +dependencies = [ + "nb 1.1.0", +] + +[[package]] +name = "nb" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d5439c4ad607c3c23abf66de8c8bf57ba8adcd1f129e699851a6e43935d339d" + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "normpath" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf23ab2b905654b4cb177e30b629937b3868311d4e1cba859f899c041046e69b" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_enum" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.116", +] + +[[package]] +name = "proc-macro-crate" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" +dependencies = [ + "toml_edit", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "redox_syscall" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35985aa610addc02e24fc232012c86fd11f14111180f902b67e2d5331f8ebf2b" +dependencies = [ + "bitflags 2.11.0", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + +[[package]] +name = "remove_dir_all" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a694f9e0eb3104451127f6cc1e5de55f59d3b1fc8c5ddfaeb6f1e716479ceb4a" +dependencies = [ + "cfg-if", + "cvt", + "fs_at", + "libc", + "normpath", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.11.0", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros 0.24.3", +] + +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros 0.25.3", +] + +[[package]] +name = "strum_macros" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 1.0.109", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.116", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix 1.1.3", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "toml_datetime" +version = "0.7.5+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.23.10+spec-1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c8b9f757e028cee9fa244aea147aab2a9ec09d5325a9b01e0a49730c2b5269" +dependencies = [ + "indexmap", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.0.8+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0742ff5ff03ea7e67c8ae6c93cac239e0d9784833362da3f9a9c1da8dfefcbdc" +dependencies = [ + "winnow", +] + +[[package]] +name = "uncased" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b88fcfe09e89d3866a5c11019378088af2d24c3fbd4f0543f96b479ec90697" +dependencies = [ + "version_check", +] + +[[package]] +name = "unicode-ident" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn 2.0.116", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.11.0", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +dependencies = [ + "memchr", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck 0.5.0", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck 0.5.0", + "indexmap", + "prettyplease", + "syn 2.0.116", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.116", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.11.0", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "zeroclaw-esp32" +version = "0.1.0" +dependencies = [ + "anyhow", + "embuild", + "esp-idf-svc", + "log", + "serde", + "serde_json", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/firmware/zeroclaw-esp32/Cargo.toml b/firmware/zeroclaw-esp32/Cargo.toml new file mode 100644 index 0000000..2ec056f --- /dev/null +++ b/firmware/zeroclaw-esp32/Cargo.toml @@ -0,0 +1,41 @@ +# ZeroClaw ESP32 firmware — JSON-over-serial peripheral for host-mediated control. +# +# Flash to ESP32 and connect via serial. The host ZeroClaw sends gpio_read/gpio_write +# commands; this firmware executes them and responds. +# +# Prerequisites: espup (cargo install espup; espup install; source ~/export-esp.sh) +# Build: cargo build --release +# Flash: cargo espflash flash --monitor + +[package] +name = "zeroclaw-esp32" +version = "0.1.0" +edition = "2021" +license = "MIT" +description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial" + +[patch.crates-io] +# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x +esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" } +esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" } +esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" } + +[dependencies] +esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" } +log = "0.4" +anyhow = "1.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +[build-dependencies] +embuild = { version = "0.33", features = ["espidf"] } + +[profile.release] +opt-level = "s" +lto = true +codegen-units = 1 +strip = true +panic = "abort" + +[profile.dev] +opt-level = "s" diff --git a/firmware/zeroclaw-esp32/README.md b/firmware/zeroclaw-esp32/README.md new file mode 100644 index 0000000..f4b2c08 --- /dev/null +++ b/firmware/zeroclaw-esp32/README.md @@ -0,0 +1,80 @@ +# ZeroClaw ESP32 Firmware + +Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial. + +**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting. + +## Protocol + + +- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n` +- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n` + +Commands: `gpio_read`, `gpio_write`. + +## Prerequisites + +1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`. + + **Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14: + ```sh + brew install python@3.12 + ``` + + **virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS): + ```sh + /opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages + ``` + + **Rust tools**: + ```sh + cargo install espflash ldproxy + ``` + + The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build: + ```sh + export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" + ``` + +2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead: + ```sh + cargo install espup espflash + espup install + source ~/export-esp.sh + ``` + Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`). + +## Build & Flash + +```sh +cd firmware/zeroclaw-esp32 +# Use Python 3.12 (required if you have 3.14) +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +# Optional: pin MCU (esp32c3 or esp32c2) +export MCU=esp32c3 +cargo build --release +espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor +``` + +## Host Config + +Add to `config.toml`: + +```toml +[peripherals] +enabled = true + +[[peripherals.boards]] +board = "esp32" +transport = "serial" +path = "/dev/ttyUSB0" # or /dev/ttyACM0, COM3, etc. +baud = 115200 +``` + +## Pin Mapping + +Default GPIO 2 and 13 are configured for output. Edit `src/main.rs` to add more pins or change for your board. ESP32-C3 has different pin layout — adjust UART pins (gpio21/gpio20) if needed. + +## Edge-Native (Future) + +Phase 6 also envisions ZeroClaw running *on* the ESP32 (WiFi + LLM). This firmware is the host-mediated serial peripheral; edge-native will be a separate crate. diff --git a/firmware/zeroclaw-esp32/SETUP.md b/firmware/zeroclaw-esp32/SETUP.md new file mode 100644 index 0000000..0624f4d --- /dev/null +++ b/firmware/zeroclaw-esp32/SETUP.md @@ -0,0 +1,156 @@ +# ESP32 Firmware Setup Guide + +Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues. + +## Quick Start (copy-paste) + +```sh +# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14) +brew install python@3.12 + +# 2. Install virtualenv (PEP 668 workaround on macOS) +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages + +# 3. Install Rust tools +cargo install espflash ldproxy + +# 4. Build +cd firmware/zeroclaw-esp32 +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +cargo build --release + +# 5. Flash (connect ESP32 via USB) +espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor +``` + +--- + +## Detailed Steps + +### 1. Python + +ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.** + +```sh +brew install python@3.12 +``` + +### 2. virtualenv + +ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use: + +```sh +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages +``` + +### 3. Rust Tools + +```sh +cargo install espflash ldproxy +``` + +- **espflash**: flash and monitor +- **ldproxy**: linker for ESP-IDF builds + +### 4. Use Python 3.12 for Builds + +Before every build (or add to `~/.zshrc`): + +```sh +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +``` + +### 5. Build + +```sh +cd firmware/zeroclaw-esp32 +cargo build --release +``` + +First build downloads and compiles ESP-IDF (~5–15 min). + +### 6. Flash + +```sh +espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor +``` + +--- + +## Troubleshooting + +### "No space left on device" + +Free disk space. Common targets: + +```sh +# Cargo cache (often 5–20 GB) +rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src + +# Unused Rust toolchains +rustup toolchain list +rustup toolchain uninstall + +# iOS Simulator runtimes (~35 GB) +xcrun simctl delete unavailable + +# Temp files +rm -rf /var/folders/*/T/cargo-install* +``` + +### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed" + +This project uses **nightly Rust with build-std**, not espup. Ensure: + +- `rust-toolchain.toml` exists (pins nightly + rust-src) +- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets) +- Run `cargo build` from `firmware/zeroclaw-esp32` + +### "externally-managed-environment" / "No module named 'virtualenv'" + +Install virtualenv with the PEP 668 workaround: + +```sh +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages +``` + +### "expected `i64`, found `i32`" (time_t mismatch) + +Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`. + +### "expected `*const u8`, found `*const i8`" (esp-idf-svc) + +Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch. + +### 10,000+ files in `git status` + +The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains: + +``` +.embuild/ +``` + +--- + +## Optional: Auto-load Python 3.12 + +Add to `~/.zshrc`: + +```sh +# ESP32 firmware build +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +``` + +--- + +## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3) + +For non–RISC-V chips, use espup instead: + +```sh +cargo install espup espflash +espup install +source ~/export-esp.sh +``` + +Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target). diff --git a/firmware/zeroclaw-esp32/build.rs b/firmware/zeroclaw-esp32/build.rs new file mode 100644 index 0000000..112ec3f --- /dev/null +++ b/firmware/zeroclaw-esp32/build.rs @@ -0,0 +1,3 @@ +fn main() { + embuild::espidf::sysenv::output(); +} diff --git a/firmware/zeroclaw-esp32/rust-toolchain.toml b/firmware/zeroclaw-esp32/rust-toolchain.toml new file mode 100644 index 0000000..f70d225 --- /dev/null +++ b/firmware/zeroclaw-esp32/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +components = ["rust-src"] diff --git a/firmware/zeroclaw-esp32/src/main.rs b/firmware/zeroclaw-esp32/src/main.rs new file mode 100644 index 0000000..a85b67d --- /dev/null +++ b/firmware/zeroclaw-esp32/src/main.rs @@ -0,0 +1,163 @@ +//! ZeroClaw ESP32 firmware — JSON-over-serial peripheral. +//! +//! Listens for newline-delimited JSON commands on UART0, executes gpio_read/gpio_write, +//! responds with JSON. Compatible with host ZeroClaw SerialPeripheral protocol. +//! +//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md + +use esp_idf_svc::hal::gpio::PinDriver; +use esp_idf_svc::hal::peripherals::Peripherals; +use esp_idf_svc::hal::uart::{UartConfig, UartDriver}; +use esp_idf_svc::hal::units::Hertz; +use log::info; +use serde::{Deserialize, Serialize}; + +/// Incoming command from host. +#[derive(Debug, Deserialize)] +struct Request { + id: String, + cmd: String, + args: serde_json::Value, +} + +/// Outgoing response to host. +#[derive(Debug, Serialize)] +struct Response { + id: String, + ok: bool, + result: String, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +fn main() -> anyhow::Result<()> { + esp_idf_svc::sys::link_patches(); + esp_idf_svc::log::EspLogger::initialize_default(); + + let peripherals = Peripherals::take()?; + let pins = peripherals.pins; + + // Create GPIO output drivers first (they take ownership of pins) + let mut gpio2 = PinDriver::output(pins.gpio2)?; + let mut gpio13 = PinDriver::output(pins.gpio13)?; + + // UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board + let config = UartConfig::new().baudrate(Hertz(115_200)); + let uart = UartDriver::new( + peripherals.uart0, + pins.gpio21, + pins.gpio20, + Option::::None, + Option::::None, + &config, + )?; + + info!("ZeroClaw ESP32 firmware ready on UART0 (115200)"); + + let mut buf = [0u8; 512]; + let mut line = Vec::new(); + + loop { + match uart.read(&mut buf, 100) { + Ok(0) => continue, + Ok(n) => { + for &b in &buf[..n] { + if b == b'\n' { + if !line.is_empty() { + if let Ok(line_str) = std::str::from_utf8(&line) { + if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13) + { + let out = serde_json::to_string(&resp).unwrap_or_default(); + let _ = uart.write(format!("{}\n", out).as_bytes()); + } + } + line.clear(); + } + } else { + line.push(b); + if line.len() > 400 { + line.clear(); + } + } + } + } + Err(_) => {} + } + } +} + +fn handle_request( + line: &str, + gpio2: &mut PinDriver<'_, G2>, + gpio13: &mut PinDriver<'_, G13>, +) -> anyhow::Result +where + G2: esp_idf_svc::hal::gpio::OutputMode, + G13: esp_idf_svc::hal::gpio::OutputMode, +{ + let req: Request = serde_json::from_str(line.trim())?; + let id = req.id.clone(); + + let result = match req.cmd.as_str() { + "capabilities" => { + // Phase C: report GPIO pins and LED pin (matches Arduino protocol) + let caps = serde_json::json!({ + "gpio": [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 17, 18, 19], + "led_pin": 2 + }); + Ok(caps.to_string()) + } + "gpio_read" => { + let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; + let value = gpio_read(pin_num)?; + Ok(value.to_string()) + } + "gpio_write" => { + let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; + let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0); + gpio_write(gpio2, gpio13, pin_num, value)?; + Ok("done".into()) + } + _ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)), + }; + + match result { + Ok(r) => Ok(Response { + id, + ok: true, + result: r, + error: None, + }), + Err(e) => Ok(Response { + id, + ok: false, + result: String::new(), + error: Some(e.to_string()), + }), + } +} + +fn gpio_read(_pin: i32) -> anyhow::Result { + // TODO: implement input pin read — requires storing InputPin drivers per pin + Ok(0) +} + +fn gpio_write( + gpio2: &mut PinDriver<'_, G2>, + gpio13: &mut PinDriver<'_, G13>, + pin: i32, + value: u64, +) -> anyhow::Result<()> +where + G2: esp_idf_svc::hal::gpio::OutputMode, + G13: esp_idf_svc::hal::gpio::OutputMode, +{ + let level = esp_idf_svc::hal::gpio::Level::from(value != 0); + + match pin { + 2 => gpio2.set_level(level)?, + 13 => gpio13.set_level(level)?, + _ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin), + } + Ok(()) +} diff --git a/firmware/zeroclaw-nucleo/Cargo.lock b/firmware/zeroclaw-nucleo/Cargo.lock new file mode 100644 index 0000000..41b57b5 --- /dev/null +++ b/firmware/zeroclaw-nucleo/Cargo.lock @@ -0,0 +1,849 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aligned" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4508988c62edf04abd8d92897fca0c2995d907ce1dfeaf369dac3716a40685" +dependencies = [ + "as-slice", +] + +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bare-metal" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5deb64efa5bd81e31fcd1938615a6d98c82eafcbcd787162b6f63b91d6bac5b3" +dependencies = [ + "rustc_version", +] + +[[package]] +name = "bit_field" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e4b40c7323adcfc0a41c4b88143ed58346ff65a288fc144329c5c45e05d70c6" + +[[package]] +name = "bitfield" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46afbd2983a5d5a7bd740ccb198caf5b82f45c40c09c0eed36052d91cb92e719" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "block-device-driver" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c051592f59fe68053524b4c4935249b806f72c1f544cfb7abe4f57c3be258e" +dependencies = [ + "aligned", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cortex-m" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ec610d8f49840a5b376c69663b6369e71f4b34484b9b2eb29fb918d92516cb9" +dependencies = [ + "bare-metal", + "bitfield", + "embedded-hal 0.2.7", + "volatile-register", +] + +[[package]] +name = "cortex-m-rt" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "801d4dec46b34c299ccf6b036717ae0fce602faa4f4fe816d9013b9a7c9f5ba6" +dependencies = [ + "cortex-m-rt-macros", +] + +[[package]] +name = "cortex-m-rt-macros" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e37549a379a9e0e6e576fd208ee60394ccb8be963889eebba3ffe0980364f472" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.116", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "defmt" +version = "0.3.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0963443817029b2024136fc4dd07a5107eb8f977eaf18fcd1fdeb11306b64ad" +dependencies = [ + "defmt 1.0.1", +] + +[[package]] +name = "defmt" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "548d977b6da32fa1d1fda2876453da1e7df63ad0304c8b3dae4dbe7b96f39b78" +dependencies = [ + "bitflags 1.3.2", + "defmt-macros", +] + +[[package]] +name = "defmt-macros" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d4fc12a85bcf441cfe44344c4b72d58493178ce635338a3f3b78943aceb258e" +dependencies = [ + "defmt-parser", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "defmt-parser" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10d60334b3b2e7c9d91ef8150abfb6fa4c1c39ebbcf4a81c2e346aad939fee3e" +dependencies = [ + "thiserror", +] + +[[package]] +name = "defmt-rtt" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d5a25c99d89c40f5676bec8cefe0614f17f0f40e916f98e345dae941807f9e" +dependencies = [ + "critical-section", + "defmt 1.0.1", +] + +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + +[[package]] +name = "embassy-embedded-hal" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "554e3e840696f54b4c9afcf28a0f24da431c927f4151040020416e7393d6d0d8" +dependencies = [ + "defmt 1.0.1", + "embassy-futures", + "embassy-hal-internal 0.3.0", + "embassy-sync", + "embassy-time", + "embedded-hal 0.2.7", + "embedded-hal 1.0.0", + "embedded-hal-async", + "embedded-storage", + "embedded-storage-async", + "nb 1.1.0", +] + +[[package]] +name = "embassy-executor" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06070468370195e0e86f241c8e5004356d696590a678d47d6676795b2e439c6b" +dependencies = [ + "cortex-m", + "critical-section", + "defmt 1.0.1", + "document-features", + "embassy-executor-macros", + "embassy-executor-timer-queue", +] + +[[package]] +name = "embassy-executor-macros" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfdddc3a04226828316bf31393b6903ee162238576b1584ee2669af215d55472" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "embassy-executor-timer-queue" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fc328bf943af66b80b98755db9106bf7e7471b0cf47dc8559cd9a6be504cc9c" + +[[package]] +name = "embassy-futures" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01" + +[[package]] +name = "embassy-hal-internal" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95285007a91b619dc9f26ea8f55452aa6c60f7115a4edc05085cd2bd3127cd7a" +dependencies = [ + "num-traits", +] + +[[package]] +name = "embassy-hal-internal" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f10ce10a4dfdf6402d8e9bd63128986b96a736b1a0a6680547ed2ac55d55dba" +dependencies = [ + "cortex-m", + "critical-section", + "defmt 1.0.1", + "num-traits", +] + +[[package]] +name = "embassy-net-driver" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524eb3c489760508f71360112bca70f6e53173e6fe48fc5f0efd0f5ab217751d" +dependencies = [ + "defmt 0.3.100", +] + +[[package]] +name = "embassy-stm32" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "088d65743a48f2cc9b3ae274ed85d6e8b68bd3ee92eb6b87b15dca2f81f7a101" +dependencies = [ + "aligned", + "bit_field", + "bitflags 2.11.0", + "block-device-driver", + "cfg-if", + "cortex-m", + "cortex-m-rt", + "critical-section", + "defmt 1.0.1", + "document-features", + "embassy-embedded-hal", + "embassy-futures", + "embassy-hal-internal 0.4.0", + "embassy-net-driver", + "embassy-sync", + "embassy-time", + "embassy-time-driver", + "embassy-time-queue-utils", + "embassy-usb-driver", + "embassy-usb-synopsys-otg", + "embedded-can", + "embedded-hal 0.2.7", + "embedded-hal 1.0.0", + "embedded-hal-async", + "embedded-hal-nb", + "embedded-io 0.7.1", + "embedded-io-async 0.7.0", + "embedded-storage", + "embedded-storage-async", + "futures-util", + "heapless 0.9.2", + "nb 1.1.0", + "proc-macro2", + "quote", + "rand_core 0.6.4", + "rand_core 0.9.5", + "sdio-host", + "static_assertions", + "stm32-fmc", + "stm32-metapac", + "trait-set", + "vcell", + "volatile-register", +] + +[[package]] +name = "embassy-sync" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b" +dependencies = [ + "cfg-if", + "critical-section", + "defmt 1.0.1", + "embedded-io-async 0.6.1", + "futures-core", + "futures-sink", + "heapless 0.8.0", +] + +[[package]] +name = "embassy-time" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fa65b9284d974dad7a23bb72835c4ec85c0b540d86af7fc4098c88cff51d65" +dependencies = [ + "cfg-if", + "critical-section", + "defmt 1.0.1", + "document-features", + "embassy-time-driver", + "embedded-hal 0.2.7", + "embedded-hal 1.0.0", + "embedded-hal-async", + "futures-core", +] + +[[package]] +name = "embassy-time-driver" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0a244c7dc22c8d0289379c8d8830cae06bb93d8f990194d0de5efb3b5ae7ba6" +dependencies = [ + "document-features", +] + +[[package]] +name = "embassy-time-queue-utils" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e2ee86063bd028a420a5fb5898c18c87a8898026da1d4c852af2c443d0a454" +dependencies = [ + "embassy-executor-timer-queue", + "heapless 0.8.0", +] + +[[package]] +name = "embassy-usb-driver" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17119855ccc2d1f7470a39756b12068454ae27a3eabb037d940b5c03d9c77b7a" +dependencies = [ + "defmt 1.0.1", + "embedded-io-async 0.6.1", +] + +[[package]] +name = "embassy-usb-synopsys-otg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "288751f8eaa44a5cf2613f13cee0ca8e06e6638cb96e897e6834702c79084b23" +dependencies = [ + "critical-section", + "defmt 1.0.1", + "embassy-sync", + "embassy-usb-driver", +] + +[[package]] +name = "embedded-can" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9d2e857f87ac832df68fa498d18ddc679175cf3d2e4aa893988e5601baf9438" +dependencies = [ + "nb 1.1.0", +] + +[[package]] +name = "embedded-hal" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35949884794ad573cf46071e41c9b60efb0cb311e3ca01f7af807af1debc66ff" +dependencies = [ + "nb 0.1.3", + "void", +] + +[[package]] +name = "embedded-hal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89" + +[[package]] +name = "embedded-hal-async" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4c685bbef7fe13c3c6dd4da26841ed3980ef33e841cddfa15ce8a8fb3f1884" +dependencies = [ + "embedded-hal 1.0.0", +] + +[[package]] +name = "embedded-hal-nb" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fba4268c14288c828995299e59b12babdbe170f6c6d73731af1b4648142e8605" +dependencies = [ + "embedded-hal 1.0.0", + "nb 1.1.0", +] + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + +[[package]] +name = "embedded-io" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eb1aa714776b75c7e67e1da744b81a129b3ff919c8712b5e1b32252c1f07cc7" +dependencies = [ + "defmt 1.0.1", +] + +[[package]] +name = "embedded-io-async" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff09972d4073aa8c299395be75161d582e7629cd663171d62af73c8d50dba3f" +dependencies = [ + "embedded-io 0.6.1", +] + +[[package]] +name = "embedded-io-async" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2564b9f813c544241430e147d8bc454815ef9ac998878d30cc3055449f7fd4c0" +dependencies = [ + "defmt 1.0.1", + "embedded-io 0.7.1", +] + +[[package]] +name = "embedded-storage" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21dea9854beb860f3062d10228ce9b976da520a73474aed3171ec276bc0c032" + +[[package]] +name = "embedded-storage-async" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1763775e2323b7d5f0aa6090657f5e21cfa02ede71f5dc40eead06d64dcd15cc" +dependencies = [ + "embedded-storage", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", +] + +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + +[[package]] +name = "heapless" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af2455f757db2b292a9b1768c4b70186d443bcb3b316252d6b540aec1cd89ed" +dependencies = [ + "hash32", + "stable_deref_trait", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + +[[package]] +name = "nb" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f" +dependencies = [ + "nb 1.1.0", +] + +[[package]] +name = "nb" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d5439c4ad607c3c23abf66de8c8bf57ba8adcd1f129e699851a6e43935d339d" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "panic-probe" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd402d00b0fb94c5aee000029204a46884b1262e0c443f166d86d2c0747e1a1a" +dependencies = [ + "cortex-m", + "defmt 1.0.1", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" + +[[package]] +name = "rustc_version" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +dependencies = [ + "semver", +] + +[[package]] +name = "sdio-host" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b328e2cb950eeccd55b7f55c3a963691455dcd044cfb5354f0c5e68d2c2d6ee2" + +[[package]] +name = "semver" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" +dependencies = [ + "semver-parser", +] + +[[package]] +name = "semver-parser" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "stm32-fmc" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72692594faa67f052e5e06dd34460951c21e83bc55de4feb8d2666e2f15480a2" +dependencies = [ + "embedded-hal 1.0.0", +] + +[[package]] +name = "stm32-metapac" +version = "19.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a411079520dbccc613af73172f944b7cf97ba84e3bd7381a0352b6ec7bfef03b" +dependencies = [ + "cortex-m", + "cortex-m-rt", + "defmt 0.3.100", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df424c70518695237746f84cede799c9c58fcb37450d7b23716568cc8bc69cb" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "trait-set" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b79e2e9c9ab44c6d7c20d5976961b47e8f49ac199154daa514b77cd1ab536625" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "vcell" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77439c1b53d2303b20d9459b1ade71a83c716e3f9c34f3228c00e6f185d6c002" + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "volatile-register" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de437e2a6208b014ab52972a27e59b33fa2920d3e00fe05026167a1c509d19cc" +dependencies = [ + "vcell", +] + +[[package]] +name = "zeroclaw-nucleo" +version = "0.1.0" +dependencies = [ + "cortex-m-rt", + "critical-section", + "defmt 1.0.1", + "defmt-rtt", + "embassy-executor", + "embassy-stm32", + "embassy-time", + "heapless 0.9.2", + "panic-probe", +] diff --git a/firmware/zeroclaw-nucleo/Cargo.toml b/firmware/zeroclaw-nucleo/Cargo.toml new file mode 100644 index 0000000..a5d97f8 --- /dev/null +++ b/firmware/zeroclaw-nucleo/Cargo.toml @@ -0,0 +1,39 @@ +# ZeroClaw Nucleo-F401RE firmware — JSON-over-serial peripheral. +# +# Listens for newline-delimited JSON on USART2 (PA2/PA3, ST-Link VCP). +# Protocol: same as Arduino/ESP32 — ping, capabilities, gpio_read, gpio_write. +# +# Build: cargo build --release +# Flash: probe-rs run --chip STM32F401RETx target/thumbv7em-none-eabihf/release/zeroclaw-nucleo +# Or: zeroclaw peripheral flash-nucleo + +[package] +name = "zeroclaw-nucleo" +version = "0.1.0" +edition = "2021" +license = "MIT" +description = "ZeroClaw Nucleo-F401RE peripheral firmware — GPIO over JSON serial" + +[dependencies] +embassy-executor = { version = "0.9", features = ["arch-cortex-m", "executor-thread", "defmt"] } +embassy-stm32 = { version = "0.5", features = ["defmt", "stm32f401re", "unstable-pac", "memory-x", "time-driver-tim4", "exti"] } +embassy-time = { version = "0.5", features = ["defmt", "defmt-timestamp-uptime", "tick-hz-32_768"] } +defmt = "1.0" +defmt-rtt = "1.0" +panic-probe = { version = "1.0", features = ["print-defmt"] } +heapless = { version = "0.9", default-features = false } +critical-section = "1.1" +cortex-m-rt = "0.7" + +[package.metadata.embassy] +build = [ + { target = "thumbv7em-none-eabihf", artifact-dir = "target" } +] + +[profile.release] +opt-level = "s" +lto = true +codegen-units = 1 +strip = true +panic = "abort" +debug = 1 diff --git a/firmware/zeroclaw-nucleo/src/main.rs b/firmware/zeroclaw-nucleo/src/main.rs new file mode 100644 index 0000000..909645e --- /dev/null +++ b/firmware/zeroclaw-nucleo/src/main.rs @@ -0,0 +1,187 @@ +//! ZeroClaw Nucleo-F401RE firmware — JSON-over-serial peripheral. +//! +//! Listens for newline-delimited JSON on USART2 (PA2=TX, PA3=RX). +//! USART2 is connected to ST-Link VCP — host sees /dev/ttyACM0 (Linux) or /dev/cu.usbmodem* (macOS). +//! +//! Protocol: same as Arduino/ESP32 — see docs/hardware-peripherals-design.md + +#![no_std] +#![no_main] + +use core::fmt::Write; +use core::str; +use defmt::info; +use embassy_executor::Spawner; +use embassy_stm32::gpio::{Level, Output, Speed}; +use embassy_stm32::usart::{Config, Uart}; +use heapless::String; +use {defmt_rtt as _, panic_probe as _}; + +/// Arduino-style pin 13 = PA5 (User LED LD2 on Nucleo-F401RE) +const LED_PIN: u8 = 13; + +/// Parse integer from JSON: "pin":13 or "value":1 +fn parse_arg(line: &[u8], key: &[u8]) -> Option { + // key like b"pin" -> search for b"\"pin\":" + let mut suffix: [u8; 32] = [0; 32]; + suffix[0] = b'"'; + let mut len = 1; + for (i, &k) in key.iter().enumerate() { + if i >= 30 { + break; + } + suffix[len] = k; + len += 1; + } + suffix[len] = b'"'; + suffix[len + 1] = b':'; + len += 2; + let suffix = &suffix[..len]; + + let line_len = line.len(); + if line_len < len { + return None; + } + for i in 0..=line_len - len { + if line[i..].starts_with(suffix) { + let rest = &line[i + len..]; + let mut num: i32 = 0; + let mut neg = false; + let mut j = 0; + if j < rest.len() && rest[j] == b'-' { + neg = true; + j += 1; + } + while j < rest.len() && rest[j].is_ascii_digit() { + num = num * 10 + (rest[j] - b'0') as i32; + j += 1; + } + return Some(if neg { -num } else { num }); + } + } + None +} + +fn has_cmd(line: &[u8], cmd: &[u8]) -> bool { + let mut pat: [u8; 64] = [0; 64]; + pat[0..7].copy_from_slice(b"\"cmd\":\""); + let clen = cmd.len().min(50); + pat[7..7 + clen].copy_from_slice(&cmd[..clen]); + pat[7 + clen] = b'"'; + let pat = &pat[..8 + clen]; + + let line_len = line.len(); + if line_len < pat.len() { + return false; + } + for i in 0..=line_len - pat.len() { + if line[i..].starts_with(pat) { + return true; + } + } + false +} + +/// Extract "id" for response +fn copy_id(line: &[u8], out: &mut [u8]) -> usize { + let prefix = b"\"id\":\""; + if line.len() < prefix.len() + 1 { + out[0] = b'0'; + return 1; + } + for i in 0..=line.len() - prefix.len() { + if line[i..].starts_with(prefix) { + let start = i + prefix.len(); + let mut j = 0; + while start + j < line.len() && j < out.len() - 1 && line[start + j] != b'"' { + out[j] = line[start + j]; + j += 1; + } + return j; + } + } + out[0] = b'0'; + 1 +} + +#[embassy_executor::main] +async fn main(_spawner: Spawner) { + let p = embassy_stm32::init(Default::default()); + + let mut config = Config::default(); + config.baudrate = 115_200; + + let mut usart = Uart::new_blocking(p.USART2, p.PA3, p.PA2, config).unwrap(); + let mut led = Output::new(p.PA5, Level::Low, Speed::Low); + + info!("ZeroClaw Nucleo firmware ready on USART2 (115200)"); + + let mut line_buf: heapless::Vec = heapless::Vec::new(); + let mut id_buf = [0u8; 16]; + let mut resp_buf: String<128> = String::new(); + + loop { + let mut byte = [0u8; 1]; + if usart.blocking_read(&mut byte).is_ok() { + let b = byte[0]; + if b == b'\n' || b == b'\r' { + if !line_buf.is_empty() { + let id_len = copy_id(&line_buf, &mut id_buf); + let id_str = str::from_utf8(&id_buf[..id_len]).unwrap_or("0"); + + resp_buf.clear(); + if has_cmd(&line_buf, b"ping") { + let _ = write!(resp_buf, "{{\"id\":\"{}\",\"ok\":true,\"result\":\"pong\"}}", id_str); + } else if has_cmd(&line_buf, b"capabilities") { + let _ = write!( + resp_buf, + "{{\"id\":\"{}\",\"ok\":true,\"result\":\"{{\\\"gpio\\\":[0,1,2,3,4,5,6,7,8,9,10,11,12,13],\\\"led_pin\\\":13}}\"}}", + id_str + ); + } else if has_cmd(&line_buf, b"gpio_read") { + let pin = parse_arg(&line_buf, b"pin").unwrap_or(-1); + if pin == LED_PIN as i32 { + // Output doesn't support read; return 0 (LED state not readable) + let _ = write!(resp_buf, "{{\"id\":\"{}\",\"ok\":true,\"result\":\"0\"}}", id_str); + } else if pin >= 0 && pin <= 13 { + let _ = write!(resp_buf, "{{\"id\":\"{}\",\"ok\":true,\"result\":\"0\"}}", id_str); + } else { + let _ = write!( + resp_buf, + "{{\"id\":\"{}\",\"ok\":false,\"result\":\"\",\"error\":\"Invalid pin {}\"}}", + id_str, pin + ); + } + } else if has_cmd(&line_buf, b"gpio_write") { + let pin = parse_arg(&line_buf, b"pin").unwrap_or(-1); + let value = parse_arg(&line_buf, b"value").unwrap_or(0); + if pin == LED_PIN as i32 { + led.set_level(if value != 0 { Level::High } else { Level::Low }); + let _ = write!(resp_buf, "{{\"id\":\"{}\",\"ok\":true,\"result\":\"done\"}}", id_str); + } else if pin >= 0 && pin <= 13 { + let _ = write!(resp_buf, "{{\"id\":\"{}\",\"ok\":true,\"result\":\"done\"}}", id_str); + } else { + let _ = write!( + resp_buf, + "{{\"id\":\"{}\",\"ok\":false,\"result\":\"\",\"error\":\"Invalid pin {}\"}}", + id_str, pin + ); + } + } else { + let _ = write!( + resp_buf, + "{{\"id\":\"{}\",\"ok\":false,\"result\":\"\",\"error\":\"Unknown command\"}}", + id_str + ); + } + + let _ = usart.blocking_write(resp_buf.as_bytes()); + let _ = usart.blocking_write(b"\n"); + line_buf.clear(); + } + } else if line_buf.push(b).is_err() { + line_buf.clear(); + } + } + } +} diff --git a/firmware/zeroclaw-uno-q-bridge/app.yaml b/firmware/zeroclaw-uno-q-bridge/app.yaml new file mode 100644 index 0000000..32c5eb6 --- /dev/null +++ b/firmware/zeroclaw-uno-q-bridge/app.yaml @@ -0,0 +1,9 @@ +name: ZeroClaw Bridge +description: "GPIO bridge for ZeroClaw — exposes digitalWrite/digitalRead via socket for agent control" +icon: 🦀 +version: "1.0.0" + +ports: + - 9999 + +bricks: [] diff --git a/firmware/zeroclaw-uno-q-bridge/python/main.py b/firmware/zeroclaw-uno-q-bridge/python/main.py new file mode 100644 index 0000000..d4b286b --- /dev/null +++ b/firmware/zeroclaw-uno-q-bridge/python/main.py @@ -0,0 +1,66 @@ +# ZeroClaw Bridge — socket server for GPIO control from ZeroClaw agent +# SPDX-License-Identifier: MPL-2.0 + +import socket +import threading +from arduino.app_utils import App, Bridge + +ZEROCLAW_PORT = 9999 + +def handle_client(conn): + try: + data = conn.recv(256).decode().strip() + if not data: + conn.close() + return + parts = data.split() + if len(parts) < 2: + conn.sendall(b"error: invalid command\n") + conn.close() + return + cmd = parts[0].lower() + if cmd == "gpio_write" and len(parts) >= 3: + pin = int(parts[1]) + value = int(parts[2]) + Bridge.call("digitalWrite", [pin, value]) + conn.sendall(b"ok\n") + elif cmd == "gpio_read" and len(parts) >= 2: + pin = int(parts[1]) + val = Bridge.call("digitalRead", [pin]) + conn.sendall(f"{val}\n".encode()) + else: + conn.sendall(b"error: unknown command\n") + except Exception as e: + try: + conn.sendall(f"error: {e}\n".encode()) + except Exception: + pass + finally: + conn.close() + +def accept_loop(server): + while True: + try: + conn, _ = server.accept() + t = threading.Thread(target=handle_client, args=(conn,)) + t.daemon = True + t.start() + except Exception: + break + +def loop(): + App.sleep(1) + +def main(): + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(("127.0.0.1", ZEROCLAW_PORT)) + server.listen(5) + server.settimeout(1.0) + t = threading.Thread(target=accept_loop, args=(server,)) + t.daemon = True + t.start() + App.run(user_loop=loop) + +if __name__ == "__main__": + main() diff --git a/firmware/zeroclaw-uno-q-bridge/python/requirements.txt b/firmware/zeroclaw-uno-q-bridge/python/requirements.txt new file mode 100644 index 0000000..a7fe2e0 --- /dev/null +++ b/firmware/zeroclaw-uno-q-bridge/python/requirements.txt @@ -0,0 +1 @@ +# ZeroClaw Bridge — no extra deps (arduino.app_utils is preinstalled on Uno Q) diff --git a/firmware/zeroclaw-uno-q-bridge/sketch/sketch.ino b/firmware/zeroclaw-uno-q-bridge/sketch/sketch.ino new file mode 100644 index 0000000..0e7b11b --- /dev/null +++ b/firmware/zeroclaw-uno-q-bridge/sketch/sketch.ino @@ -0,0 +1,24 @@ +// ZeroClaw Bridge — expose digitalWrite/digitalRead for agent GPIO control +// SPDX-License-Identifier: MPL-2.0 + +#include "Arduino_RouterBridge.h" + +void gpio_write(int pin, int value) { + pinMode(pin, OUTPUT); + digitalWrite(pin, value ? HIGH : LOW); +} + +int gpio_read(int pin) { + pinMode(pin, INPUT); + return digitalRead(pin); +} + +void setup() { + Bridge.begin(); + Bridge.provide("digitalWrite", gpio_write); + Bridge.provide("digitalRead", gpio_read); +} + +void loop() { + Bridge.update(); +} diff --git a/firmware/zeroclaw-uno-q-bridge/sketch/sketch.yaml b/firmware/zeroclaw-uno-q-bridge/sketch/sketch.yaml new file mode 100644 index 0000000..d9fe917 --- /dev/null +++ b/firmware/zeroclaw-uno-q-bridge/sketch/sketch.yaml @@ -0,0 +1,11 @@ +profiles: + default: + fqbn: arduino:zephyr:unoq + platforms: + - platform: arduino:zephyr + libraries: + - MsgPack (0.4.2) + - DebugLog (0.8.4) + - ArxContainer (0.7.0) + - ArxTypeTraits (0.3.1) +default_profile: default diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000..0f04f3e --- /dev/null +++ b/python/README.md @@ -0,0 +1,154 @@ +# zeroclaw-tools + +Python companion package for [ZeroClaw](https://github.com/zeroclaw-labs/zeroclaw) — LangGraph-based tool calling for consistent LLM agent execution. + +## Why This Package? + +Some LLM providers (particularly GLM-5/Zhipu and similar models) have inconsistent tool calling behavior when using text-based tool invocation. This package provides a LangGraph-based approach that delivers: + +- **Consistent tool calling** across all OpenAI-compatible providers +- **Automatic tool loop** — keeps calling tools until the task is complete +- **Easy extensibility** — add new tools with a simple `@tool` decorator +- **Framework agnostic** — works with any OpenAI-compatible API + +## Installation + +```bash +pip install zeroclaw-tools +``` + +With Discord integration: + +```bash +pip install zeroclaw-tools[discord] +``` + +## Quick Start + +### Basic Agent + +```python +import asyncio +from zeroclaw_tools import create_agent, shell, file_read, file_write +from langchain_core.messages import HumanMessage + +async def main(): + # Create agent with tools + agent = create_agent( + tools=[shell, file_read, file_write], + model="glm-5", + api_key="your-api-key", + base_url="https://api.z.ai/api/coding/paas/v4" + ) + + # Execute a task + result = await agent.ainvoke({ + "messages": [HumanMessage(content="List files in /tmp directory")] + }) + + print(result["messages"][-1].content) + +asyncio.run(main()) +``` + +### CLI Usage + +```bash +# Set environment variables +export API_KEY="your-api-key" +export API_BASE="https://api.z.ai/api/coding/paas/v4" + +# Run the CLI +zeroclaw-tools "List files in the current directory" + +# Interactive mode (no message required) +zeroclaw-tools -i +``` + +### Discord Bot + +```python +import os +from zeroclaw_tools.integrations import DiscordBot + +bot = DiscordBot( + token=os.environ["DISCORD_TOKEN"], + guild_id=123456789, + allowed_users=["123456789"] +) + +bot.run() +``` + +## Available Tools + +| Tool | Description | +|------|-------------| +| `shell` | Execute shell commands | +| `file_read` | Read file contents | +| `file_write` | Write content to files | +| `web_search` | Search the web (requires Brave API key) | +| `http_request` | Make HTTP requests | +| `memory_store` | Store data in memory | +| `memory_recall` | Recall stored data | + +## Creating Custom Tools + +```python +from zeroclaw_tools import tool + +@tool +def my_custom_tool(query: str) -> str: + """Description of what this tool does.""" + # Your implementation here + return f"Result for: {query}" + +# Use with agent +agent = create_agent(tools=[my_custom_tool]) +``` + +## Provider Compatibility + +Works with any OpenAI-compatible provider: + +- **Z.AI / GLM-5** — `https://api.z.ai/api/coding/paas/v4` +- **OpenRouter** — `https://openrouter.ai/api/v1` +- **Groq** — `https://api.groq.com/openai/v1` +- **DeepSeek** — `https://api.deepseek.com` +- **Ollama** — `http://localhost:11434/v1` +- **And many more...** + +## Architecture + +``` +┌─────────────────────────────────────────────┐ +│ Your Application │ +├─────────────────────────────────────────────┤ +│ zeroclaw-tools Agent │ +│ ┌─────────────────────────────────────┐ │ +│ │ LangGraph StateGraph │ │ +│ │ ┌───────────┐ ┌──────────┐ │ │ +│ │ │ Agent │───▶│ Tools │ │ │ +│ │ │ Node │◀───│ Node │ │ │ +│ │ └───────────┘ └──────────┘ │ │ +│ └─────────────────────────────────────┘ │ +├─────────────────────────────────────────────┤ +│ OpenAI-Compatible LLM Provider │ +└─────────────────────────────────────────────┘ +``` + +## Comparison with Rust ZeroClaw + +| Feature | Rust ZeroClaw | zeroclaw-tools | +|---------|---------------|----------------| +| **Binary size** | ~3.4 MB | Python package | +| **Memory** | <5 MB | ~50 MB | +| **Startup** | <10ms | ~500ms | +| **Tool consistency** | Model-dependent | LangGraph guarantees | +| **Extensibility** | Rust traits | Python decorators | + +Use **Rust ZeroClaw** for production edge deployments. Use **zeroclaw-tools** when you need guaranteed tool calling consistency or Python ecosystem integration. + +## License + +MIT License — see [LICENSE](../LICENSE) diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 0000000..dea680b --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "zeroclaw-tools" +version = "0.1.0" +description = "Python companion package for ZeroClaw - LangGraph-based tool calling for consistent LLM agent execution" +readme = "README.md" +license = "MIT" +requires-python = ">=3.10" +authors = [ + { name = "ZeroClaw Community" } +] +keywords = [ + "ai", + "llm", + "agent", + "langgraph", + "zeroclaw", + "tool-calling", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "langgraph>=0.2.0", + "langchain-core>=0.3.0", + "langchain-openai>=0.2.0", + "httpx>=0.25.0", +] + +[project.scripts] +zeroclaw-tools = "zeroclaw_tools.__main__:main" + +[project.optional-dependencies] +discord = ["discord.py>=2.3.0"] +telegram = ["python-telegram-bot>=20.0"] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "ruff>=0.1.0", +] + +[project.urls] +Homepage = "https://github.com/zeroclaw-labs/zeroclaw" +Documentation = "https://github.com/zeroclaw-labs/zeroclaw/tree/main/python" +Repository = "https://github.com/zeroclaw-labs/zeroclaw" +Issues = "https://github.com/zeroclaw-labs/zeroclaw/issues" + +[tool.hatch.build.targets.wheel] +packages = ["zeroclaw_tools"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/test_tools.py b/python/tests/test_tools.py new file mode 100644 index 0000000..c5242c7 --- /dev/null +++ b/python/tests/test_tools.py @@ -0,0 +1,103 @@ +""" +Tests for zeroclaw-tools package. +""" + +import pytest + + +def test_import_main(): + """Test that main package imports work.""" + from zeroclaw_tools import create_agent, shell, file_read, file_write + + assert callable(create_agent) + assert hasattr(shell, "invoke") + assert hasattr(file_read, "invoke") + assert hasattr(file_write, "invoke") + + +def test_import_tool_decorator(): + """Test that tool decorator works.""" + from zeroclaw_tools import tool + + @tool + def test_func(x: str) -> str: + """Test tool.""" + return x + + assert hasattr(test_func, "invoke") + + +def test_tool_decorator_custom_metadata(): + """Test that custom tool metadata is preserved.""" + from zeroclaw_tools import tool + + @tool(name="echo_tool", description="Echo input back") + def echo(value: str) -> str: + return value + + assert echo.name == "echo_tool" + assert "Echo input back" in echo.description + + +def test_agent_creation(): + """Test that agent can be created with default tools.""" + from zeroclaw_tools import create_agent, shell, file_read, file_write + + agent = create_agent( + tools=[shell, file_read, file_write], model="test-model", api_key="test-key" + ) + + assert agent is not None + assert agent.model == "test-model" + + +def test_cli_allows_interactive_without_message(): + """Interactive mode should not require positional message.""" + from zeroclaw_tools.__main__ import parse_args + + args = parse_args(["-i"]) + + assert args.interactive is True + assert args.message == [] + + +def test_cli_requires_message_when_not_interactive(): + """Non-interactive mode requires at least one message token.""" + from zeroclaw_tools.__main__ import parse_args + + with pytest.raises(SystemExit): + parse_args([]) + + +@pytest.mark.asyncio +async def test_invoke_in_event_loop_raises(): + """invoke() should fail fast when called from an active event loop.""" + from zeroclaw_tools import create_agent, shell + + agent = create_agent(tools=[shell], model="test-model", api_key="test-key") + + with pytest.raises(RuntimeError, match="ainvoke"): + agent.invoke({"messages": []}) + + +@pytest.mark.asyncio +async def test_shell_tool(): + """Test shell tool execution.""" + from zeroclaw_tools import shell + + result = await shell.ainvoke({"command": "echo hello"}) + assert "hello" in result + + +@pytest.mark.asyncio +async def test_file_tools(tmp_path): + """Test file read/write tools.""" + from zeroclaw_tools import file_read, file_write + + test_file = tmp_path / "test.txt" + + write_result = await file_write.ainvoke({"path": str(test_file), "content": "Hello, World!"}) + assert "Successfully" in write_result + + read_result = await file_read.ainvoke({"path": str(test_file)}) + assert "Hello, World!" in read_result diff --git a/python/zeroclaw_tools/__init__.py b/python/zeroclaw_tools/__init__.py new file mode 100644 index 0000000..be72de5 --- /dev/null +++ b/python/zeroclaw_tools/__init__.py @@ -0,0 +1,32 @@ +""" +ZeroClaw Tools - LangGraph-based tool calling for consistent LLM agent execution. + +This package provides a reliable tool-calling layer for LLM providers that may have +inconsistent native tool calling behavior. Built on LangGraph for guaranteed execution. +""" + +from .agent import create_agent, ZeroclawAgent +from .tools import ( + shell, + file_read, + file_write, + web_search, + http_request, + memory_store, + memory_recall, +) +from .tools.base import tool + +__version__ = "0.1.0" +__all__ = [ + "create_agent", + "ZeroclawAgent", + "tool", + "shell", + "file_read", + "file_write", + "web_search", + "http_request", + "memory_store", + "memory_recall", +] diff --git a/python/zeroclaw_tools/__main__.py b/python/zeroclaw_tools/__main__.py new file mode 100644 index 0000000..1d284a5 --- /dev/null +++ b/python/zeroclaw_tools/__main__.py @@ -0,0 +1,133 @@ +""" +CLI entry point for zeroclaw-tools. +""" + +import argparse +import asyncio +import os +import sys +from typing import Optional + +from langchain_core.messages import HumanMessage + +from .agent import create_agent +from .tools import ( + shell, + file_read, + file_write, + web_search, + http_request, + memory_store, + memory_recall, +) + + +DEFAULT_SYSTEM_PROMPT = """You are ZeroClaw, an AI assistant with full system access. Use tools to accomplish tasks. +Be concise and helpful. Execute tools directly without excessive explanation.""" + + +async def chat(message: str, api_key: str, base_url: Optional[str], model: str) -> str: + """Run a single chat message through the agent.""" + agent = create_agent( + tools=[shell, file_read, file_write, web_search, http_request, memory_store, memory_recall], + model=model, + api_key=api_key, + base_url=base_url, + system_prompt=DEFAULT_SYSTEM_PROMPT, + ) + + result = await agent.ainvoke({"messages": [HumanMessage(content=message)]}) + return result["messages"][-1].content or "Done." + + +def _build_parser() -> argparse.ArgumentParser: + """Build CLI argument parser.""" + parser = argparse.ArgumentParser( + description="ZeroClaw Tools - LangGraph-based tool calling for LLMs" + ) + parser.add_argument( + "message", + nargs="*", + help="Message to send to the agent (optional in interactive mode)", + ) + parser.add_argument("--model", "-m", default="glm-5", help="Model to use") + parser.add_argument("--api-key", "-k", default=None, help="API key") + parser.add_argument("--base-url", "-u", default=None, help="API base URL") + parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode") + return parser + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse CLI arguments and enforce mode-specific requirements.""" + parser = _build_parser() + args = parser.parse_args(argv) + + if not args.interactive and not args.message: + parser.error("message is required unless --interactive is set") + + return args + + +def main(argv: list[str] | None = None): + """CLI main entry point.""" + args = parse_args(argv) + + api_key = args.api_key or os.environ.get("API_KEY") or os.environ.get("GLM_API_KEY") + base_url = args.base_url or os.environ.get("API_BASE") + + if not api_key: + print("Error: API key required. Set API_KEY env var or use --api-key", file=sys.stderr) + sys.exit(1) + + if args.interactive: + print("ZeroClaw Tools CLI (Interactive Mode)") + print("Type 'exit' to quit\n") + + agent = create_agent( + tools=[ + shell, + file_read, + file_write, + web_search, + http_request, + memory_store, + memory_recall, + ], + model=args.model, + api_key=api_key, + base_url=base_url, + system_prompt=DEFAULT_SYSTEM_PROMPT, + ) + + history = [] + + while True: + try: + user_input = input("You: ").strip() + if not user_input: + continue + if user_input.lower() in ["exit", "quit", "q"]: + print("Goodbye!") + break + + history.append(HumanMessage(content=user_input)) + + result = asyncio.run(agent.ainvoke({"messages": history})) + + for msg in result["messages"][len(history) :]: + history.append(msg) + + response = result["messages"][-1].content or "Done." + print(f"\nZeroClaw: {response}\n") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + else: + message = " ".join(args.message) + result = asyncio.run(chat(message, api_key, base_url, args.model)) + print(result) + + +if __name__ == "__main__": + main() diff --git a/python/zeroclaw_tools/agent.py b/python/zeroclaw_tools/agent.py new file mode 100644 index 0000000..35e9ab2 --- /dev/null +++ b/python/zeroclaw_tools/agent.py @@ -0,0 +1,173 @@ +""" +LangGraph-based agent factory for consistent tool calling. +""" + +import os +from typing import Any, Optional + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.tools import BaseTool +from langchain_openai import ChatOpenAI +from langgraph.graph import StateGraph, MessagesState, END +from langgraph.prebuilt import ToolNode + + +SYSTEM_PROMPT = """You are ZeroClaw, an AI assistant with tool access. Use tools to accomplish tasks. +Be concise and helpful. Execute tools directly when needed without excessive explanation.""" +GLM_DEFAULT_BASE_URL = "https://api.z.ai/api/coding/paas/v4" + + +class ZeroclawAgent: + """ + LangGraph-based agent with consistent tool calling behavior. + + This agent wraps an LLM with LangGraph's tool execution loop, ensuring + reliable tool calling even with providers that have inconsistent native + tool calling support. + """ + + def __init__( + self, + tools: list[BaseTool], + model: str = "glm-5", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + temperature: float = 0.7, + system_prompt: Optional[str] = None, + ): + self.tools = tools + self.model = model + self.temperature = temperature + self.system_prompt = system_prompt or SYSTEM_PROMPT + + api_key = api_key or os.environ.get("API_KEY") or os.environ.get("GLM_API_KEY") + base_url = base_url or os.environ.get("API_BASE") + + if base_url is None and model.lower().startswith(("glm", "zhipu")): + base_url = GLM_DEFAULT_BASE_URL + + if not api_key: + raise ValueError( + "API key required. Set API_KEY environment variable or pass api_key parameter." + ) + + self.llm = ChatOpenAI( + model=model, + api_key=api_key, + base_url=base_url, + temperature=temperature, + ).bind_tools(tools) + + self._graph = self._build_graph() + + def _build_graph(self) -> StateGraph: + """Build the LangGraph execution graph.""" + tool_node = ToolNode(self.tools) + + def should_continue(state: MessagesState) -> str: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "tools" + return END + + async def call_model(state: MessagesState) -> dict: + response = await self.llm.ainvoke(state["messages"]) + return {"messages": [response]} + + workflow = StateGraph(MessagesState) + workflow.add_node("agent", call_model) + workflow.add_node("tools", tool_node) + workflow.set_entry_point("agent") + workflow.add_conditional_edges("agent", should_continue, {"tools": "tools", END: END}) + workflow.add_edge("tools", "agent") + + return workflow.compile() + + async def ainvoke(self, input: dict[str, Any], config: Optional[dict] = None) -> dict: + """ + Asynchronously invoke the agent. + + Args: + input: Dict with "messages" key containing list of messages + config: Optional LangGraph config + + Returns: + Dict with "messages" key containing the conversation + """ + messages = input.get("messages", []) + + if messages and isinstance(messages[0], HumanMessage): + if not any(isinstance(m, SystemMessage) for m in messages): + messages = [SystemMessage(content=self.system_prompt)] + messages + + return await self._graph.ainvoke({"messages": messages}, config) + + def invoke(self, input: dict[str, Any], config: Optional[dict] = None) -> dict: + """ + Synchronously invoke the agent. + """ + import asyncio + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.ainvoke(input, config)) + + raise RuntimeError( + "ZeroclawAgent.invoke() cannot be called inside an active event loop. " + "Use 'await ZeroclawAgent.ainvoke(...)' instead." + ) + + +def create_agent( + tools: Optional[list[BaseTool]] = None, + model: str = "glm-5", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + temperature: float = 0.7, + system_prompt: Optional[str] = None, +) -> ZeroclawAgent: + """ + Create a ZeroClaw agent with LangGraph-based tool calling. + + Args: + tools: List of tools. Defaults to shell, file_read, file_write. + model: Model name to use + api_key: API key for the provider + base_url: Base URL for the provider API + temperature: Sampling temperature + system_prompt: Custom system prompt + + Returns: + Configured ZeroclawAgent instance + + Example: + ```python + from zeroclaw_tools import create_agent, shell, file_read + from langchain_core.messages import HumanMessage + + agent = create_agent( + tools=[shell, file_read], + model="glm-5", + api_key="your-key" + ) + + result = await agent.ainvoke({ + "messages": [HumanMessage(content="List files in /tmp")] + }) + ``` + """ + if tools is None: + from .tools import shell, file_read, file_write + + tools = [shell, file_read, file_write] + + return ZeroclawAgent( + tools=tools, + model=model, + api_key=api_key, + base_url=base_url, + temperature=temperature, + system_prompt=system_prompt, + ) diff --git a/python/zeroclaw_tools/integrations/__init__.py b/python/zeroclaw_tools/integrations/__init__.py new file mode 100644 index 0000000..ef58dbb --- /dev/null +++ b/python/zeroclaw_tools/integrations/__init__.py @@ -0,0 +1,7 @@ +""" +Integrations for supported external platforms. +""" + +from .discord_bot import DiscordBot + +__all__ = ["DiscordBot"] diff --git a/python/zeroclaw_tools/integrations/discord_bot.py b/python/zeroclaw_tools/integrations/discord_bot.py new file mode 100644 index 0000000..298f9f6 --- /dev/null +++ b/python/zeroclaw_tools/integrations/discord_bot.py @@ -0,0 +1,177 @@ +""" +Discord bot integration for ZeroClaw. +""" + +import os +from typing import Optional, Set + +try: + import discord + + DISCORD_AVAILABLE = True +except ImportError: + DISCORD_AVAILABLE = False + discord = None + +from langchain_core.messages import HumanMessage + +from ..agent import create_agent +from ..tools import shell, file_read, file_write, web_search + + +class DiscordBot: + """ + Discord bot powered by ZeroClaw agent with LangGraph tool calling. + + Example: + ```python + import os + from zeroclaw_tools.integrations import DiscordBot + + bot = DiscordBot( + token=os.environ["DISCORD_TOKEN"], + guild_id=123456789, + allowed_users=["123456789"], + api_key=os.environ["API_KEY"] + ) + + bot.run() + ``` + """ + + def __init__( + self, + token: str, + guild_id: int, + allowed_users: list[str], + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = "glm-5", + prefix: str = "", + ): + if not DISCORD_AVAILABLE: + raise ImportError( + "discord.py is required for Discord integration. " + "Install with: pip install zeroclaw-tools[discord]" + ) + + self.token = token + self.guild_id = guild_id + self.allowed_users: Set[str] = set(allowed_users) + self.api_key = api_key or os.environ.get("API_KEY") + self.base_url = base_url or os.environ.get("API_BASE") + self.model = model + self.prefix = prefix + + if not self.api_key: + raise ValueError( + "API key required. Set API_KEY environment variable or pass api_key parameter." + ) + + self.agent = create_agent( + tools=[shell, file_read, file_write, web_search], + model=self.model, + api_key=self.api_key, + base_url=self.base_url, + ) + + self._histories: dict[str, list] = {} + self._max_history = 20 + + intents = discord.Intents.default() + intents.message_content = True + intents.guilds = True + + self.client = discord.Client(intents=intents) + self._setup_events() + + def _setup_events(self): + @self.client.event + async def on_ready(): + print(f"ZeroClaw Discord Bot ready: {self.client.user}") + print(f"Guild: {self.guild_id}") + print(f"Allowed users: {self.allowed_users}") + + @self.client.event + async def on_message(message): + if message.author == self.client.user: + return + + if message.guild and message.guild.id != self.guild_id: + return + + user_id = str(message.author.id) + if user_id not in self.allowed_users: + return + + content = message.content.strip() + if not content: + return + + if self.prefix and not content.startswith(self.prefix): + return + + if self.prefix: + content = content[len(self.prefix) :].strip() + + print(f"[{message.author}] {content[:50]}...") + + async with message.channel.typing(): + try: + response = await self._process_message(content, user_id) + for chunk in self._split_message(response): + await message.reply(chunk) + except Exception as e: + print(f"Error: {e}") + await message.reply(f"Error: {e}") + + async def _process_message(self, content: str, user_id: str) -> str: + """Process a message and return the response.""" + messages = [] + + if user_id in self._histories: + for msg in self._histories[user_id][-10:]: + messages.append(msg) + + messages.append(HumanMessage(content=content)) + + result = await self.agent.ainvoke({"messages": messages}) + + if user_id not in self._histories: + self._histories[user_id] = [] + self._histories[user_id].append(HumanMessage(content=content)) + + for msg in result["messages"][len(messages) :]: + self._histories[user_id].append(msg) + + self._histories[user_id] = self._histories[user_id][-self._max_history * 2 :] + + final = result["messages"][-1] + return final.content or "Done." + + @staticmethod + def _split_message(text: str, max_len: int = 1900) -> list[str]: + """Split long messages for Discord's character limit.""" + if len(text) <= max_len: + return [text] + + chunks = [] + while text: + if len(text) <= max_len: + chunks.append(text) + break + + pos = text.rfind("\n", 0, max_len) + if pos == -1: + pos = text.rfind(" ", 0, max_len) + if pos == -1: + pos = max_len + + chunks.append(text[:pos].strip()) + text = text[pos:].strip() + + return chunks + + def run(self): + """Start the Discord bot.""" + self.client.run(self.token) diff --git a/python/zeroclaw_tools/tools/__init__.py b/python/zeroclaw_tools/tools/__init__.py new file mode 100644 index 0000000..230becf --- /dev/null +++ b/python/zeroclaw_tools/tools/__init__.py @@ -0,0 +1,20 @@ +""" +Built-in tools for ZeroClaw agents. +""" + +from .base import tool +from .shell import shell +from .file import file_read, file_write +from .web import web_search, http_request +from .memory import memory_store, memory_recall + +__all__ = [ + "tool", + "shell", + "file_read", + "file_write", + "web_search", + "http_request", + "memory_store", + "memory_recall", +] diff --git a/python/zeroclaw_tools/tools/base.py b/python/zeroclaw_tools/tools/base.py new file mode 100644 index 0000000..12fe337 --- /dev/null +++ b/python/zeroclaw_tools/tools/base.py @@ -0,0 +1,50 @@ +""" +Base utilities for creating tools. +""" + +from typing import Any, Callable, Optional + +from langchain_core.tools import tool as langchain_tool + + +def tool( + func: Optional[Callable] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, +) -> Any: + """ + Decorator to create a LangChain tool from a function. + + This is a convenience wrapper around langchain_core.tools.tool that + provides a simpler interface for ZeroClaw users. + + Args: + func: The function to wrap (when used without parentheses) + name: Optional custom name for the tool + description: Optional custom description + + Returns: + A BaseTool instance + + Example: + ```python + from zeroclaw_tools import tool + + @tool + def my_tool(query: str) -> str: + \"\"\"Description of what this tool does.\"\"\" + return f"Result: {query}" + ``` + """ + if func is not None: + if name is not None: + return langchain_tool(name, func, description=description) + return langchain_tool(func, description=description) + + def decorator(f: Callable) -> Any: + if name is not None: + return langchain_tool(name, f, description=description) + return langchain_tool(f, description=description) + + return decorator diff --git a/python/zeroclaw_tools/tools/file.py b/python/zeroclaw_tools/tools/file.py new file mode 100644 index 0000000..92265e7 --- /dev/null +++ b/python/zeroclaw_tools/tools/file.py @@ -0,0 +1,60 @@ +""" +File read/write tools. +""" + +import os + +from langchain_core.tools import tool + + +MAX_FILE_SIZE = 100_000 + + +@tool +def file_read(path: str) -> str: + """ + Read the contents of a file at the given path. + + Args: + path: The file path to read (absolute or relative) + + Returns: + The file contents, or an error message + """ + try: + with open(path, "r", encoding="utf-8", errors="replace") as f: + content = f.read() + if len(content) > MAX_FILE_SIZE: + return content[:MAX_FILE_SIZE] + f"\n... (truncated, {len(content)} bytes total)" + return content + except FileNotFoundError: + return f"Error: File not found: {path}" + except PermissionError: + return f"Error: Permission denied: {path}" + except Exception as e: + return f"Error: {e}" + + +@tool +def file_write(path: str, content: str) -> str: + """ + Write content to a file, creating directories if needed. + + Args: + path: The file path to write to + content: The content to write + + Returns: + Success message or error + """ + try: + parent = os.path.dirname(path) + if parent: + os.makedirs(parent, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + f.write(content) + return f"Successfully wrote {len(content)} bytes to {path}" + except PermissionError: + return f"Error: Permission denied: {path}" + except Exception as e: + return f"Error: {e}" diff --git a/python/zeroclaw_tools/tools/memory.py b/python/zeroclaw_tools/tools/memory.py new file mode 100644 index 0000000..f9586ce --- /dev/null +++ b/python/zeroclaw_tools/tools/memory.py @@ -0,0 +1,85 @@ +""" +Memory storage tools for persisting data between conversations. +""" + +import json +from pathlib import Path + +from langchain_core.tools import tool + + +def _get_memory_path() -> Path: + """Get the path to the memory storage file.""" + return Path.home() / ".zeroclaw" / "memory_store.json" + + +def _load_memory() -> dict: + """Load memory from disk.""" + path = _get_memory_path() + if not path.exists(): + return {} + try: + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + +def _save_memory(data: dict) -> None: + """Save memory to disk.""" + path = _get_memory_path() + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +@tool +def memory_store(key: str, value: str) -> str: + """ + Store a key-value pair in persistent memory. + + Args: + key: The key to store under + value: The value to store + + Returns: + Confirmation message + """ + try: + data = _load_memory() + data[key] = value + _save_memory(data) + return f"Stored: {key}" + except Exception as e: + return f"Error: {e}" + + +@tool +def memory_recall(query: str) -> str: + """ + Search memory for entries matching the query. + + Args: + query: The search query + + Returns: + Matching entries or "no matches" message + """ + try: + data = _load_memory() + if not data: + return "No memories stored yet" + + query_lower = query.lower() + matches = { + k: v + for k, v in data.items() + if query_lower in k.lower() or query_lower in str(v).lower() + } + + if not matches: + return f"No matches for: {query}" + + return json.dumps(matches, indent=2) + except Exception as e: + return f"Error: {e}" diff --git a/python/zeroclaw_tools/tools/shell.py b/python/zeroclaw_tools/tools/shell.py new file mode 100644 index 0000000..81e896f --- /dev/null +++ b/python/zeroclaw_tools/tools/shell.py @@ -0,0 +1,32 @@ +""" +Shell execution tool. +""" + +import subprocess + +from langchain_core.tools import tool + + +@tool +def shell(command: str) -> str: + """ + Execute a shell command and return the output. + + Args: + command: The shell command to execute + + Returns: + The command output (stdout and stderr combined) + """ + try: + result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=60) + output = result.stdout + if result.stderr: + output += f"\nSTDERR: {result.stderr}" + if result.returncode != 0: + output += f"\nExit code: {result.returncode}" + return output or "(no output)" + except subprocess.TimeoutExpired: + return "Error: Command timed out after 60 seconds" + except Exception as e: + return f"Error: {e}" diff --git a/python/zeroclaw_tools/tools/web.py b/python/zeroclaw_tools/tools/web.py new file mode 100644 index 0000000..110770b --- /dev/null +++ b/python/zeroclaw_tools/tools/web.py @@ -0,0 +1,88 @@ +""" +Web-related tools: HTTP requests and web search. +""" + +import json +import os +import urllib.error +import urllib.parse +import urllib.request + +from langchain_core.tools import tool + + +@tool +def http_request(url: str, method: str = "GET", headers: str = "", body: str = "") -> str: + """ + Make an HTTP request to a URL. + + Args: + url: The URL to request + method: HTTP method (GET, POST, PUT, DELETE, etc.) + headers: Comma-separated headers in format "Name: Value, Name2: Value2" + body: Request body for POST/PUT requests + + Returns: + The response status and body + """ + try: + req_headers = {"User-Agent": "ZeroClaw/1.0"} + if headers: + for h in headers.split(","): + if ":" in h: + k, v = h.split(":", 1) + req_headers[k.strip()] = v.strip() + + data = body.encode() if body else None + req = urllib.request.Request(url, data=data, headers=req_headers, method=method.upper()) + + with urllib.request.urlopen(req, timeout=30) as resp: + body_text = resp.read().decode("utf-8", errors="replace") + return f"Status: {resp.status}\n{body_text[:5000]}" + except urllib.error.HTTPError as e: + error_body = e.read().decode("utf-8", errors="replace")[:1000] + return f"HTTP Error {e.code}: {error_body}" + except Exception as e: + return f"Error: {e}" + + +@tool +def web_search(query: str) -> str: + """ + Search the web using Brave Search API. + + Requires BRAVE_API_KEY environment variable to be set. + + Args: + query: The search query + + Returns: + Search results as formatted text + """ + api_key = os.environ.get("BRAVE_API_KEY", "") + if not api_key: + return "Error: BRAVE_API_KEY environment variable not set. Get one at https://brave.com/search/api/" + + try: + encoded_query = urllib.parse.quote(query) + url = f"https://api.search.brave.com/res/v1/web/search?q={encoded_query}" + + req = urllib.request.Request( + url, headers={"Accept": "application/json", "X-Subscription-Token": api_key} + ) + + with urllib.request.urlopen(req, timeout=10) as resp: + data = json.loads(resp.read().decode()) + results = [] + + for item in data.get("web", {}).get("results", [])[:5]: + title = item.get("title", "No title") + url_link = item.get("url", "") + desc = item.get("description", "")[:200] + results.append(f"- {title}\n {url_link}\n {desc}") + + if not results: + return "No results found" + return "\n\n".join(results) + except Exception as e: + return f"Error: {e}" diff --git a/quick_test.sh b/quick_test.sh new file mode 100755 index 0000000..07f0eac --- /dev/null +++ b/quick_test.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Quick smoke test for Telegram integration +# Run this before committing code changes + +set -e + +echo "🔥 Quick Telegram Smoke Test" +echo "" + +# Test 1: Compile check +echo -n "1. Compiling... " +cargo build --release --quiet 2>&1 && echo "✓" || { echo "✗ FAILED"; exit 1; } + +# Test 2: Unit tests +echo -n "2. Running tests... " +cargo test telegram_split --lib --quiet 2>&1 && echo "✓" || { echo "✗ FAILED"; exit 1; } + +# Test 3: Health check +echo -n "3. Health check... " +timeout 7 target/release/zeroclaw channel doctor &>/dev/null && echo "✓" || echo "⚠ (configure bot first)" + +# Test 4: File checks +echo -n "4. Code structure... " +grep -q "TELEGRAM_MAX_MESSAGE_LENGTH" src/channels/telegram.rs && \ +grep -q "split_message_for_telegram" src/channels/telegram.rs && \ +grep -q "tokio::time::timeout" src/channels/telegram.rs && \ +echo "✓" || { echo "✗ FAILED"; exit 1; } + +echo "" +echo "✅ Quick tests passed! Run ./test_telegram_integration.sh for full suite." diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..f19782d --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "1.92.0" diff --git a/scripts/ci/collect_changed_links.py b/scripts/ci/collect_changed_links.py new file mode 100755 index 0000000..01b45fe --- /dev/null +++ b/scripts/ci/collect_changed_links.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import os +import re +import subprocess +import sys +from pathlib import Path + + +DOC_PATH_RE = re.compile(r"\.mdx?$") +URL_RE = re.compile(r"https?://[^\s<>'\"]+") +INLINE_LINK_RE = re.compile(r"!?\[[^\]]*\]\(([^)]+)\)") +REF_LINK_RE = re.compile(r"^\s*\[[^\]]+\]:\s*(\S+)") +TRAILING_PUNCTUATION = ").,;:!?]}'\"" + + +def run_git(args: list[str]) -> subprocess.CompletedProcess[str]: + return subprocess.run(["git", *args], check=False, capture_output=True, text=True) + + +def commit_exists(rev: str) -> bool: + if not rev: + return False + return run_git(["cat-file", "-e", f"{rev}^{{commit}}"]).returncode == 0 + + +def normalize_docs_files(raw: str) -> list[str]: + if not raw: + return [] + files: list[str] = [] + for line in raw.splitlines(): + path = line.strip() + if path: + files.append(path) + return files + + +def infer_base_sha(provided: str) -> str: + if commit_exists(provided): + return provided + if run_git(["rev-parse", "--verify", "origin/main"]).returncode != 0: + return "" + proc = run_git(["merge-base", "origin/main", "HEAD"]) + candidate = proc.stdout.strip() + return candidate if commit_exists(candidate) else "" + + +def infer_docs_files(base_sha: str, provided: list[str]) -> list[str]: + if provided: + return provided + if not base_sha: + return [] + diff = run_git(["diff", "--name-only", base_sha, "HEAD"]) + files: list[str] = [] + for line in diff.stdout.splitlines(): + path = line.strip() + if not path: + continue + if DOC_PATH_RE.search(path) or path in {"LICENSE", ".github/pull_request_template.md"}: + files.append(path) + return files + + +def normalize_link_target(raw_target: str, source_path: str) -> str | None: + target = raw_target.strip() + if target.startswith("<") and target.endswith(">"): + target = target[1:-1].strip() + + if not target: + return None + + if " " in target: + target = target.split()[0].strip() + + if not target or target.startswith("#"): + return None + + lower = target.lower() + if lower.startswith(("mailto:", "tel:", "javascript:")): + return None + + if target.startswith(("http://", "https://")): + return target.rstrip(TRAILING_PUNCTUATION) + + path_without_fragment = target.split("#", 1)[0].split("?", 1)[0] + if not path_without_fragment: + return None + + if path_without_fragment.startswith("/"): + resolved = path_without_fragment.lstrip("/") + else: + resolved = os.path.normpath( + os.path.join(os.path.dirname(source_path) or ".", path_without_fragment) + ) + + if not resolved or resolved == ".": + return None + + return resolved + + +def extract_links(text: str, source_path: str) -> list[str]: + links: list[str] = [] + for match in URL_RE.findall(text): + url = match.rstrip(TRAILING_PUNCTUATION) + if url: + links.append(url) + + for match in INLINE_LINK_RE.findall(text): + normalized = normalize_link_target(match, source_path) + if normalized: + links.append(normalized) + + ref_match = REF_LINK_RE.match(text) + if ref_match: + normalized = normalize_link_target(ref_match.group(1), source_path) + if normalized: + links.append(normalized) + + return links + + +def added_lines_for_file(base_sha: str, path: str) -> list[str]: + if base_sha: + diff = run_git(["diff", "--unified=0", base_sha, "HEAD", "--", path]) + lines: list[str] = [] + for raw_line in diff.stdout.splitlines(): + if raw_line.startswith("+++"): + continue + if raw_line.startswith("+"): + lines.append(raw_line[1:]) + return lines + + file_path = Path(path) + if not file_path.is_file(): + return [] + return file_path.read_text(encoding="utf-8", errors="ignore").splitlines() + + +def main() -> int: + parser = argparse.ArgumentParser(description="Collect HTTP(S) links added in changed docs lines") + parser.add_argument("--base", default="", help="Base commit SHA") + parser.add_argument( + "--docs-files", + default="", + help="Newline-separated docs files list", + ) + parser.add_argument("--output", required=True, help="Output file for unique URLs") + args = parser.parse_args() + + base_sha = infer_base_sha(args.base) + docs_files = infer_docs_files(base_sha, normalize_docs_files(args.docs_files)) + + existing_files = [path for path in docs_files if Path(path).is_file()] + if not existing_files: + Path(args.output).write_text("", encoding="utf-8") + print("No docs files available for link collection.") + return 0 + + unique_urls: list[str] = [] + seen: set[str] = set() + for path in existing_files: + for line in added_lines_for_file(base_sha, path): + for link in extract_links(line, path): + if link not in seen: + seen.add(link) + unique_urls.append(link) + + Path(args.output).write_text("\n".join(unique_urls) + ("\n" if unique_urls else ""), encoding="utf-8") + print(f"Collected {len(unique_urls)} added link(s) from {len(existing_files)} docs file(s).") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/ci/docs_links_gate.sh b/scripts/ci/docs_links_gate.sh new file mode 100755 index 0000000..95e6a3d --- /dev/null +++ b/scripts/ci/docs_links_gate.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +set -euo pipefail + +BASE_SHA="${BASE_SHA:-}" +DOCS_FILES_RAW="${DOCS_FILES:-}" + +LINKS_FILE="$(mktemp)" +trap 'rm -f "$LINKS_FILE"' EXIT + +python3 ./scripts/ci/collect_changed_links.py \ + --base "$BASE_SHA" \ + --docs-files "$DOCS_FILES_RAW" \ + --output "$LINKS_FILE" + +if [ ! -s "$LINKS_FILE" ]; then + echo "No added links detected in changed docs lines." + exit 0 +fi + +if ! command -v lychee >/dev/null 2>&1; then + echo "lychee is required to run docs link gate locally." + echo "Install via: cargo install lychee" + exit 1 +fi + +echo "Checking added links with lychee (offline mode)..." +lychee --offline --no-progress --format detailed "$LINKS_FILE" diff --git a/scripts/ci/docs_quality_gate.sh b/scripts/ci/docs_quality_gate.sh new file mode 100755 index 0000000..989d81a --- /dev/null +++ b/scripts/ci/docs_quality_gate.sh @@ -0,0 +1,186 @@ +#!/usr/bin/env bash + +set -euo pipefail + +BASE_SHA="${BASE_SHA:-}" +DOCS_FILES_RAW="${DOCS_FILES:-}" + +if [ -z "$BASE_SHA" ] && git rev-parse --verify origin/main >/dev/null 2>&1; then + BASE_SHA="$(git merge-base origin/main HEAD)" +fi + +if [ -z "$DOCS_FILES_RAW" ] && [ -n "$BASE_SHA" ] && git cat-file -e "$BASE_SHA^{commit}" 2>/dev/null; then + DOCS_FILES_RAW="$(git diff --name-only "$BASE_SHA" HEAD | awk ' + /\.md$/ || /\.mdx$/ || $0 == "LICENSE" || $0 == ".github/pull_request_template.md" { + print + } + ')" +fi + +if [ -z "$DOCS_FILES_RAW" ]; then + echo "No docs files detected; skipping docs quality gate." + exit 0 +fi + +if [ -z "$BASE_SHA" ] || ! git cat-file -e "$BASE_SHA^{commit}" 2>/dev/null; then + echo "BASE_SHA is missing or invalid; falling back to full-file markdown lint." + BASE_SHA="" +fi + +ALL_FILES=() +while IFS= read -r file; do + if [ -n "$file" ]; then + ALL_FILES+=("$file") + fi +done < <(printf '%s\n' "$DOCS_FILES_RAW") + +if [ "${#ALL_FILES[@]}" -eq 0 ]; then + echo "No docs files detected after normalization; skipping docs quality gate." + exit 0 +fi + +EXISTING_FILES=() +for file in "${ALL_FILES[@]}"; do + if [ -f "$file" ]; then + EXISTING_FILES+=("$file") + fi +done + +if [ "${#EXISTING_FILES[@]}" -eq 0 ]; then + echo "No existing docs files to lint; skipping docs quality gate." + exit 0 +fi + +if command -v npx >/dev/null 2>&1; then + MD_CMD=(npx --yes markdownlint-cli2@0.20.0) +elif command -v markdownlint-cli2 >/dev/null 2>&1; then + MD_CMD=(markdownlint-cli2) +else + echo "markdownlint-cli2 is required (via npx or local binary)." + exit 1 +fi + +echo "Linting docs files: ${EXISTING_FILES[*]}" + +LINT_OUTPUT_FILE="$(mktemp)" +set +e +"${MD_CMD[@]}" "${EXISTING_FILES[@]}" >"$LINT_OUTPUT_FILE" 2>&1 +LINT_EXIT=$? +set -e + +if [ "$LINT_EXIT" -eq 0 ]; then + cat "$LINT_OUTPUT_FILE" + rm -f "$LINT_OUTPUT_FILE" + exit 0 +fi + +if [ -z "$BASE_SHA" ]; then + cat "$LINT_OUTPUT_FILE" + rm -f "$LINT_OUTPUT_FILE" + exit "$LINT_EXIT" +fi + +CHANGED_LINES_JSON_FILE="$(mktemp)" +python3 - "$BASE_SHA" "${EXISTING_FILES[@]}" >"$CHANGED_LINES_JSON_FILE" <<'PY' +import json +import re +import subprocess +import sys + +base = sys.argv[1] +files = sys.argv[2:] + +changed = {} +hunk = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@") + +for path in files: + proc = subprocess.run( + ["git", "diff", "--unified=0", base, "HEAD", "--", path], + check=False, + capture_output=True, + text=True, + ) + ranges = [] + for line in proc.stdout.splitlines(): + m = hunk.match(line) + if not m: + continue + start = int(m.group(1)) + count = int(m.group(2) or "1") + if count > 0: + ranges.append([start, start + count - 1]) + changed[path] = ranges + +print(json.dumps(changed)) +PY + +FILTERED_OUTPUT_FILE="$(mktemp)" +set +e +python3 - "$LINT_OUTPUT_FILE" "$CHANGED_LINES_JSON_FILE" >"$FILTERED_OUTPUT_FILE" <<'PY' +import json +import re +import sys + +lint_file = sys.argv[1] +changed_file = sys.argv[2] + +with open(changed_file, "r", encoding="utf-8") as f: + changed = json.load(f) + +line_re = re.compile(r"^(.+?):(\d+)\s+error\s+(MD\d+(?:/[^\s]+)?)\s+(.*)$") + +blocking = [] +baseline = [] +other_lines = [] + +with open(lint_file, "r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.rstrip("\n") + m = line_re.match(line) + if not m: + other_lines.append(line) + continue + + path, line_no_s, rule, msg = m.groups() + line_no = int(line_no_s) + ranges = changed.get(path, []) + + is_changed_line = any(start <= line_no <= end for start, end in ranges) + entry = f"{path}:{line_no} {rule} {msg}" + if is_changed_line: + blocking.append(entry) + else: + baseline.append(entry) + +if baseline: + print("Existing markdown issues outside changed lines (non-blocking):") + for entry in baseline: + print(f" - {entry}") + +if blocking: + print("Markdown issues introduced on changed lines (blocking):") + for entry in blocking: + print(f" - {entry}") + print(f"Blocking markdown issues: {len(blocking)}") + sys.exit(1) + +if baseline: + print("No blocking markdown issues on changed lines.") + sys.exit(0) + +for line in other_lines: + print(line) + +if any(line.strip() for line in other_lines): + print("markdownlint exited non-zero with unclassified output; failing safe.") + sys.exit(2) + +print("No blocking markdown issues on changed lines.") +PY +SCRIPT_EXIT=$? +set -e + +cat "$FILTERED_OUTPUT_FILE" + +rm -f "$LINT_OUTPUT_FILE" "$CHANGED_LINES_JSON_FILE" "$FILTERED_OUTPUT_FILE" +exit "$SCRIPT_EXIT" diff --git a/scripts/ci/rust_quality_gate.sh b/scripts/ci/rust_quality_gate.sh new file mode 100755 index 0000000..75e7f1d --- /dev/null +++ b/scripts/ci/rust_quality_gate.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -euo pipefail + +MODE="correctness" +if [ "${1:-}" = "--strict" ]; then + MODE="strict" +fi + +echo "==> rust quality: cargo fmt --all -- --check" +cargo fmt --all -- --check + +if [ "$MODE" = "strict" ]; then + echo "==> rust quality: cargo clippy --locked --all-targets -- -D warnings" + cargo clippy --locked --all-targets -- -D warnings +else + echo "==> rust quality: cargo clippy --locked --all-targets -- -D clippy::correctness" + cargo clippy --locked --all-targets -- -D clippy::correctness +fi diff --git a/scripts/ci/rust_strict_delta_gate.sh b/scripts/ci/rust_strict_delta_gate.sh new file mode 100755 index 0000000..5f4ccc7 --- /dev/null +++ b/scripts/ci/rust_strict_delta_gate.sh @@ -0,0 +1,237 @@ +#!/usr/bin/env bash + +set -euo pipefail + +BASE_SHA="${BASE_SHA:-}" +RUST_FILES_RAW="${RUST_FILES:-}" + +if [ -z "$BASE_SHA" ] && git rev-parse --verify origin/main >/dev/null 2>&1; then + BASE_SHA="$(git merge-base origin/main HEAD)" +fi + +if [ -z "$BASE_SHA" ] && git rev-parse --verify HEAD~1 >/dev/null 2>&1; then + BASE_SHA="$(git rev-parse HEAD~1)" +fi + +if [ -z "$BASE_SHA" ] || ! git cat-file -e "$BASE_SHA^{commit}" 2>/dev/null; then + echo "BASE_SHA is missing or invalid for strict delta gate." + echo "Set BASE_SHA explicitly or ensure origin/main is available." + exit 1 +fi + +if [ -z "$RUST_FILES_RAW" ]; then + RUST_FILES_RAW="$(git diff --name-only "$BASE_SHA" HEAD | awk '/\.rs$/ { print }')" +fi + +ALL_FILES=() +while IFS= read -r file; do + if [ -n "$file" ]; then + ALL_FILES+=("$file") + fi +done < <(printf '%s\n' "$RUST_FILES_RAW") + +if [ "${#ALL_FILES[@]}" -eq 0 ]; then + echo "No Rust source files changed; skipping strict delta gate." + exit 0 +fi + +EXISTING_FILES=() +for file in "${ALL_FILES[@]}"; do + if [ -f "$file" ]; then + EXISTING_FILES+=("$file") + fi +done + +if [ "${#EXISTING_FILES[@]}" -eq 0 ]; then + echo "No existing changed Rust files to lint; skipping strict delta gate." + exit 0 +fi + +echo "Strict delta linting changed Rust files: ${EXISTING_FILES[*]}" + +CHANGED_LINES_JSON_FILE="$(mktemp)" +CLIPPY_JSON_FILE="$(mktemp)" +CLIPPY_STDERR_FILE="$(mktemp)" +FILTERED_OUTPUT_FILE="$(mktemp)" +trap 'rm -f "$CHANGED_LINES_JSON_FILE" "$CLIPPY_JSON_FILE" "$CLIPPY_STDERR_FILE" "$FILTERED_OUTPUT_FILE"' EXIT + +python3 - "$BASE_SHA" "${EXISTING_FILES[@]}" >"$CHANGED_LINES_JSON_FILE" <<'PY' +import json +import re +import subprocess +import sys + +base = sys.argv[1] +files = sys.argv[2:] +hunk = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@") +changed = {} + +for path in files: + proc = subprocess.run( + ["git", "diff", "--unified=0", base, "HEAD", "--", path], + check=False, + capture_output=True, + text=True, + ) + ranges = [] + for line in proc.stdout.splitlines(): + match = hunk.match(line) + if not match: + continue + start = int(match.group(1)) + count = int(match.group(2) or "1") + if count > 0: + ranges.append([start, start + count - 1]) + changed[path] = ranges + +print(json.dumps(changed)) +PY + +set +e +cargo clippy --quiet --locked --all-targets --message-format=json -- -D warnings >"$CLIPPY_JSON_FILE" 2>"$CLIPPY_STDERR_FILE" +CLIPPY_EXIT=$? +set -e + +if [ "$CLIPPY_EXIT" -eq 0 ]; then + echo "Strict delta gate passed: no strict warnings/errors." + exit 0 +fi + +set +e +python3 - "$CLIPPY_JSON_FILE" "$CHANGED_LINES_JSON_FILE" >"$FILTERED_OUTPUT_FILE" <<'PY' +import json +import sys +from pathlib import Path + +messages_file = sys.argv[1] +changed_file = sys.argv[2] + +with open(changed_file, "r", encoding="utf-8") as f: + changed = json.load(f) + +cwd = Path.cwd().resolve() + + +def normalize_path(path_value: str) -> str: + path = Path(path_value) + if path.is_absolute(): + try: + return path.resolve().relative_to(cwd).as_posix() + except Exception: + return path.as_posix() + return path.as_posix() + + +blocking = [] +baseline = [] +unclassified = [] +classified_count = 0 + +with open(messages_file, "r", encoding="utf-8", errors="ignore") as f: + for raw_line in f: + line = raw_line.strip() + if not line: + continue + + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + + if payload.get("reason") != "compiler-message": + continue + + message = payload.get("message", {}) + level = message.get("level") + if level not in {"warning", "error"}: + continue + + code_obj = message.get("code") or {} + code = code_obj.get("code") if isinstance(code_obj, dict) else None + text = message.get("message", "") + spans = message.get("spans") or [] + + candidate_spans = [span for span in spans if span.get("is_primary")] + if not candidate_spans: + candidate_spans = spans + + span_entries = [] + for span in candidate_spans: + file_name = span.get("file_name") + line_start = span.get("line_start") + line_end = span.get("line_end") + if not file_name or line_start is None: + continue + norm_path = normalize_path(file_name) + span_entries.append((norm_path, int(line_start), int(line_end or line_start))) + + if not span_entries: + unclassified.append(f"{level.upper()} {code or '-'} {text}") + continue + + is_changed_line = False + best_path, best_line, _ = span_entries[0] + for path, line_start, line_end in span_entries: + ranges = changed.get(path) + if ranges is None: + continue + + for start, end in ranges: + if line_end >= start and line_start <= end: + is_changed_line = True + best_path, best_line = path, line_start + break + if is_changed_line: + break + + entry = f"{best_path}:{best_line} {level.upper()} {code or '-'} {text}" + classified_count += 1 + if is_changed_line: + blocking.append(entry) + else: + baseline.append(entry) + +if baseline: + print("Existing strict lint issues outside changed Rust lines (non-blocking):") + for entry in baseline: + print(f" - {entry}") + +if blocking: + print("Strict lint issues introduced on changed Rust lines (blocking):") + for entry in blocking: + print(f" - {entry}") + print(f"Blocking strict lint issues: {len(blocking)}") + sys.exit(1) + +if classified_count > 0: + print("No blocking strict lint issues on changed Rust lines.") + sys.exit(0) + +if unclassified: + print("Strict lint exited non-zero with unclassified diagnostics; failing safe:") + for entry in unclassified[:20]: + print(f" - {entry}") + sys.exit(2) + +print("Strict lint exited non-zero without parsable diagnostics; failing safe.") +sys.exit(2) +PY +FILTER_EXIT=$? +set -e + +cat "$FILTERED_OUTPUT_FILE" + +if [ "$FILTER_EXIT" -eq 0 ]; then + if [ -s "$CLIPPY_STDERR_FILE" ]; then + echo "clippy stderr summary (informational):" + cat "$CLIPPY_STDERR_FILE" + fi + exit 0 +fi + +if [ -s "$CLIPPY_STDERR_FILE" ]; then + echo "clippy stderr summary:" + cat "$CLIPPY_STDERR_FILE" +fi + +exit "$FILTER_EXIT" diff --git a/scripts/recompute_contributor_tiers.sh b/scripts/recompute_contributor_tiers.sh new file mode 100755 index 0000000..6e3e528 --- /dev/null +++ b/scripts/recompute_contributor_tiers.sh @@ -0,0 +1,324 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_NAME="$(basename "$0")" + +usage() { + cat < Target repository (default: current gh repo) + --kind + Target objects (default: both) + --state + State filter for listing objects (default: all) + --limit Limit processed objects after fetch (default: 0 = no limit) + --apply Apply label updates (default is dry-run) + --dry-run Preview only (default) + -h, --help Show this help + +Examples: + ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50 + ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply +USAGE +} + +die() { + echo "[$SCRIPT_NAME] ERROR: $*" >&2 + exit 1 +} + +require_cmd() { + if ! command -v "$1" >/dev/null 2>&1; then + die "Required command not found: $1" + fi +} + +urlencode() { + jq -nr --arg value "$1" '$value|@uri' +} + +select_contributor_tier() { + local merged_count="$1" + if (( merged_count >= 50 )); then + echo "distinguished contributor" + elif (( merged_count >= 20 )); then + echo "principal contributor" + elif (( merged_count >= 10 )); then + echo "experienced contributor" + elif (( merged_count >= 5 )); then + echo "trusted contributor" + else + echo "" + fi +} + +DRY_RUN=1 +KIND="both" +STATE="all" +LIMIT=0 +REPO="" + +while (($# > 0)); do + case "$1" in + --repo) + [[ $# -ge 2 ]] || die "Missing value for --repo" + REPO="$2" + shift 2 + ;; + --kind) + [[ $# -ge 2 ]] || die "Missing value for --kind" + KIND="$2" + shift 2 + ;; + --state) + [[ $# -ge 2 ]] || die "Missing value for --state" + STATE="$2" + shift 2 + ;; + --limit) + [[ $# -ge 2 ]] || die "Missing value for --limit" + LIMIT="$2" + shift 2 + ;; + --apply) + DRY_RUN=0 + shift + ;; + --dry-run) + DRY_RUN=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + die "Unknown option: $1" + ;; + esac +done + +case "$KIND" in + both|prs|issues) ;; + *) die "--kind must be one of: both, prs, issues" ;; +esac + +case "$STATE" in + all|open|closed) ;; + *) die "--state must be one of: all, open, closed" ;; +esac + +if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then + die "--limit must be a non-negative integer" +fi + +require_cmd gh +require_cmd jq + +if ! gh auth status >/dev/null 2>&1; then + die "gh CLI is not authenticated. Run: gh auth login" +fi + +if [[ -z "$REPO" ]]; then + REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)" + [[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo ." +fi + +echo "[$SCRIPT_NAME] Repo: $REPO" +echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")" +echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT" + +TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]' + +TMP_FILES=() +cleanup() { + if ((${#TMP_FILES[@]} > 0)); then + rm -f "${TMP_FILES[@]}" + fi +} +trap cleanup EXIT + +new_tmp_file() { + local tmp + tmp="$(mktemp)" + TMP_FILES+=("$tmp") + echo "$tmp" +} + +targets_file="$(new_tmp_file)" + +if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then + gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \ + --jq '.[] | { + kind: "pr", + number: .number, + author: (.user.login // ""), + author_type: (.user.type // ""), + labels: [(.labels[]?.name // empty)] + }' >> "$targets_file" +fi + +if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then + gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \ + --jq '.[] | select(.pull_request | not) | { + kind: "issue", + number: .number, + author: (.user.login // ""), + author_type: (.user.type // ""), + labels: [(.labels[]?.name // empty)] + }' >> "$targets_file" +fi + +if [[ "$LIMIT" -gt 0 ]]; then + limited_file="$(new_tmp_file)" + head -n "$LIMIT" "$targets_file" > "$limited_file" + mv "$limited_file" "$targets_file" +fi + +target_count="$(wc -l < "$targets_file" | tr -d ' ')" +if [[ "$target_count" -eq 0 ]]; then + echo "[$SCRIPT_NAME] No targets found." + exit 0 +fi + +echo "[$SCRIPT_NAME] Targets fetched: $target_count" + +# Ensure tier labels exist (trusted contributor might be new). +label_color="" +for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do + encoded_label="$(urlencode "$probe_label")" + if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then + if [[ -n "$color_candidate" ]]; then + label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')" + break + fi + fi +done +[[ -n "$label_color" ]] || label_color="C5D7A2" + +while IFS= read -r tier_label; do + [[ -n "$tier_label" ]] || continue + encoded_label="$(urlencode "$tier_label")" + if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then + continue + fi + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "[dry-run] Would create missing label: $tier_label (color=$label_color)" + else + gh api -X POST "repos/$REPO/labels" \ + -f name="$tier_label" \ + -f color="$label_color" >/dev/null + echo "[apply] Created missing label: $tier_label" + fi +done < <(jq -r '.[]' <<<"$TIERS_JSON") + +# Build merged PR count cache by unique human authors. +authors_file="$(new_tmp_file)" +jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file" +author_count="$(wc -l < "$authors_file" | tr -d ' ')" +echo "[$SCRIPT_NAME] Unique human authors: $author_count" + +author_counts_file="$(new_tmp_file)" +while IFS= read -r author; do + [[ -n "$author" ]] || continue + query="repo:$REPO is:pr is:merged author:$author" + merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)" + if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then + merged_count=0 + fi + printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file" +done < "$authors_file" + +updated=0 +unchanged=0 +skipped=0 +failed=0 + +while IFS= read -r target_json; do + [[ -n "$target_json" ]] || continue + + number="$(jq -r '.number' <<<"$target_json")" + kind="$(jq -r '.kind' <<<"$target_json")" + author="$(jq -r '.author' <<<"$target_json")" + author_type="$(jq -r '.author_type' <<<"$target_json")" + current_labels_json="$(jq -c '.labels // []' <<<"$target_json")" + + if [[ -z "$author" || "$author_type" == "Bot" ]]; then + skipped=$((skipped + 1)) + continue + fi + + merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")" + if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then + merged_count=0 + fi + desired_tier="$(select_contributor_tier "$merged_count")" + + if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2 + failed=$((failed + 1)) + continue + fi + + if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" ' + (. // []) + | map(select(. as $label | ($tiers | index($label)) == null)) + | if $desired != "" then . + [$desired] else . end + | unique + ' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2 + failed=$((failed + 1)) + continue + fi + + if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2 + failed=$((failed + 1)) + continue + fi + + if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2 + failed=$((failed + 1)) + continue + fi + + if [[ "$normalized_current" == "$normalized_next" ]]; then + unchanged=$((unchanged + 1)) + continue + fi + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'" + updated=$((updated + 1)) + continue + fi + + payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')" + if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then + echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'" + updated=$((updated + 1)) + else + echo "[apply] FAILED ${kind} #${number}" >&2 + failed=$((failed + 1)) + fi +done < "$targets_file" + +echo "" +echo "[$SCRIPT_NAME] Summary" +echo " Targets: $target_count" +echo " Updated: $updated" +echo " Unchanged: $unchanged" +echo " Skipped: $skipped" +echo " Failed: $failed" + +if [[ "$failed" -gt 0 ]]; then + exit 1 +fi diff --git a/src/agent/agent.rs b/src/agent/agent.rs new file mode 100644 index 0000000..3e5693e --- /dev/null +++ b/src/agent/agent.rs @@ -0,0 +1,708 @@ +use crate::agent::dispatcher::{ + NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher, +}; +use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader}; +use crate::agent::prompt::{PromptContext, SystemPromptBuilder}; +use crate::config::Config; +use crate::memory::{self, Memory, MemoryCategory}; +use crate::observability::{self, Observer, ObserverEvent}; +use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Provider}; +use crate::runtime; +use crate::security::SecurityPolicy; +use crate::tools::{self, Tool, ToolSpec}; +use crate::util::truncate_with_ellipsis; +use anyhow::Result; +use std::io::Write as IoWrite; +use std::sync::Arc; +use std::time::Instant; + +pub struct Agent { + provider: Box, + tools: Vec>, + tool_specs: Vec, + memory: Arc, + observer: Arc, + prompt_builder: SystemPromptBuilder, + tool_dispatcher: Box, + memory_loader: Box, + config: crate::config::AgentConfig, + model_name: String, + temperature: f64, + workspace_dir: std::path::PathBuf, + identity_config: crate::config::IdentityConfig, + skills: Vec, + auto_save: bool, + history: Vec, +} + +pub struct AgentBuilder { + provider: Option>, + tools: Option>>, + memory: Option>, + observer: Option>, + prompt_builder: Option, + tool_dispatcher: Option>, + memory_loader: Option>, + config: Option, + model_name: Option, + temperature: Option, + workspace_dir: Option, + identity_config: Option, + skills: Option>, + auto_save: Option, +} + +impl AgentBuilder { + pub fn new() -> Self { + Self { + provider: None, + tools: None, + memory: None, + observer: None, + prompt_builder: None, + tool_dispatcher: None, + memory_loader: None, + config: None, + model_name: None, + temperature: None, + workspace_dir: None, + identity_config: None, + skills: None, + auto_save: None, + } + } + + pub fn provider(mut self, provider: Box) -> Self { + self.provider = Some(provider); + self + } + + pub fn tools(mut self, tools: Vec>) -> Self { + self.tools = Some(tools); + self + } + + pub fn memory(mut self, memory: Arc) -> Self { + self.memory = Some(memory); + self + } + + pub fn observer(mut self, observer: Arc) -> Self { + self.observer = Some(observer); + self + } + + pub fn prompt_builder(mut self, prompt_builder: SystemPromptBuilder) -> Self { + self.prompt_builder = Some(prompt_builder); + self + } + + pub fn tool_dispatcher(mut self, tool_dispatcher: Box) -> Self { + self.tool_dispatcher = Some(tool_dispatcher); + self + } + + pub fn memory_loader(mut self, memory_loader: Box) -> Self { + self.memory_loader = Some(memory_loader); + self + } + + pub fn config(mut self, config: crate::config::AgentConfig) -> Self { + self.config = Some(config); + self + } + + pub fn model_name(mut self, model_name: String) -> Self { + self.model_name = Some(model_name); + self + } + + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn workspace_dir(mut self, workspace_dir: std::path::PathBuf) -> Self { + self.workspace_dir = Some(workspace_dir); + self + } + + pub fn identity_config(mut self, identity_config: crate::config::IdentityConfig) -> Self { + self.identity_config = Some(identity_config); + self + } + + pub fn skills(mut self, skills: Vec) -> Self { + self.skills = Some(skills); + self + } + + pub fn auto_save(mut self, auto_save: bool) -> Self { + self.auto_save = Some(auto_save); + self + } + + pub fn build(self) -> Result { + let tools = self + .tools + .ok_or_else(|| anyhow::anyhow!("tools are required"))?; + let tool_specs = tools.iter().map(|tool| tool.spec()).collect(); + + Ok(Agent { + provider: self + .provider + .ok_or_else(|| anyhow::anyhow!("provider is required"))?, + tools, + tool_specs, + memory: self + .memory + .ok_or_else(|| anyhow::anyhow!("memory is required"))?, + observer: self + .observer + .ok_or_else(|| anyhow::anyhow!("observer is required"))?, + prompt_builder: self + .prompt_builder + .unwrap_or_else(SystemPromptBuilder::with_defaults), + tool_dispatcher: self + .tool_dispatcher + .ok_or_else(|| anyhow::anyhow!("tool_dispatcher is required"))?, + memory_loader: self + .memory_loader + .unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())), + config: self.config.unwrap_or_default(), + model_name: self + .model_name + .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()), + temperature: self.temperature.unwrap_or(0.7), + workspace_dir: self + .workspace_dir + .unwrap_or_else(|| std::path::PathBuf::from(".")), + identity_config: self.identity_config.unwrap_or_default(), + skills: self.skills.unwrap_or_default(), + auto_save: self.auto_save.unwrap_or(false), + history: Vec::new(), + }) + } +} + +impl Agent { + pub fn builder() -> AgentBuilder { + AgentBuilder::new() + } + + pub fn history(&self) -> &[ConversationMessage] { + &self.history + } + + pub fn clear_history(&mut self) { + self.history.clear(); + } + + pub fn from_config(config: &Config) -> Result { + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let memory: Arc = Arc::from(memory::create_memory( + &config.memory, + &config.workspace_dir, + config.api_key.as_deref(), + )?); + + let composio_key = if config.composio.enabled { + config.composio.api_key.as_deref() + } else { + None + }; + let composio_entity_id = if config.composio.enabled { + Some(config.composio.entity_id.as_str()) + } else { + None + }; + + let tools = tools::all_tools_with_runtime( + Arc::new(config.clone()), + &security, + runtime, + memory.clone(), + composio_key, + composio_entity_id, + &config.browser, + &config.http_request, + &config.workspace_dir, + &config.agents, + config.api_key.as_deref(), + config, + ); + + let provider_name = config.default_provider.as_deref().unwrap_or("openrouter"); + + let model_name = config + .default_model + .as_deref() + .unwrap_or("anthropic/claude-sonnet-4-20250514") + .to_string(); + + let provider: Box = providers::create_routed_provider( + provider_name, + config.api_key.as_deref(), + config.api_url.as_deref(), + &config.reliability, + &config.model_routes, + &model_name, + )?; + + let dispatcher_choice = config.agent.tool_dispatcher.as_str(); + let tool_dispatcher: Box = match dispatcher_choice { + "native" => Box::new(NativeToolDispatcher), + "xml" => Box::new(XmlToolDispatcher), + _ if provider.supports_native_tools() => Box::new(NativeToolDispatcher), + _ => Box::new(XmlToolDispatcher), + }; + + Agent::builder() + .provider(provider) + .tools(tools) + .memory(memory) + .observer(observer) + .tool_dispatcher(tool_dispatcher) + .memory_loader(Box::new(DefaultMemoryLoader::default())) + .prompt_builder(SystemPromptBuilder::with_defaults()) + .config(config.agent.clone()) + .model_name(model_name) + .temperature(config.default_temperature) + .workspace_dir(config.workspace_dir.clone()) + .identity_config(config.identity.clone()) + .skills(crate::skills::load_skills(&config.workspace_dir)) + .auto_save(config.memory.auto_save) + .build() + } + + fn trim_history(&mut self) { + let max = self.config.max_history_messages; + if self.history.len() <= max { + return; + } + + let mut system_messages = Vec::new(); + let mut other_messages = Vec::new(); + + for msg in self.history.drain(..) { + match &msg { + ConversationMessage::Chat(chat) if chat.role == "system" => { + system_messages.push(msg); + } + _ => other_messages.push(msg), + } + } + + if other_messages.len() > max { + let drop_count = other_messages.len() - max; + other_messages.drain(0..drop_count); + } + + self.history = system_messages; + self.history.extend(other_messages); + } + + fn build_system_prompt(&self) -> Result { + let instructions = self.tool_dispatcher.prompt_instructions(&self.tools); + let ctx = PromptContext { + workspace_dir: &self.workspace_dir, + model_name: &self.model_name, + tools: &self.tools, + skills: &self.skills, + identity_config: Some(&self.identity_config), + dispatcher_instructions: &instructions, + }; + self.prompt_builder.build(&ctx) + } + + async fn execute_tool_call(&self, call: &ParsedToolCall) -> ToolExecutionResult { + let start = Instant::now(); + + let result = if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) { + match tool.execute(call.arguments.clone()).await { + Ok(r) => { + self.observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: r.success, + }); + if r.success { + r.output + } else { + format!("Error: {}", r.error.unwrap_or(r.output)) + } + } + Err(e) => { + self.observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: false, + }); + format!("Error executing {}: {e}", call.name) + } + } + } else { + format!("Unknown tool: {}", call.name) + }; + + ToolExecutionResult { + name: call.name.clone(), + output: result, + success: true, + tool_call_id: call.tool_call_id.clone(), + } + } + + async fn execute_tools(&self, calls: &[ParsedToolCall]) -> Vec { + if !self.config.parallel_tools { + let mut results = Vec::with_capacity(calls.len()); + for call in calls { + results.push(self.execute_tool_call(call).await); + } + return results; + } + + let mut results = Vec::with_capacity(calls.len()); + for call in calls { + results.push(self.execute_tool_call(call).await); + } + results + } + + pub async fn turn(&mut self, user_message: &str) -> Result { + if self.history.is_empty() { + let system_prompt = self.build_system_prompt()?; + self.history + .push(ConversationMessage::Chat(ChatMessage::system( + system_prompt, + ))); + } + + if self.auto_save { + let _ = self + .memory + .store("user_msg", user_message, MemoryCategory::Conversation, None) + .await; + } + + let context = self + .memory_loader + .load_context(self.memory.as_ref(), user_message) + .await + .unwrap_or_default(); + + let enriched = if context.is_empty() { + user_message.to_string() + } else { + format!("{context}{user_message}") + }; + + self.history + .push(ConversationMessage::Chat(ChatMessage::user(enriched))); + + for _ in 0..self.config.max_tool_iterations { + let messages = self.tool_dispatcher.to_provider_messages(&self.history); + let response = match self + .provider + .chat( + ChatRequest { + messages: &messages, + tools: if self.tool_dispatcher.should_send_tool_specs() { + Some(&self.tool_specs) + } else { + None + }, + }, + &self.model_name, + self.temperature, + ) + .await + { + Ok(resp) => resp, + Err(err) => return Err(err), + }; + + let (text, calls) = self.tool_dispatcher.parse_response(&response); + if calls.is_empty() { + let final_text = if text.is_empty() { + response.text.unwrap_or_default() + } else { + text + }; + + self.history + .push(ConversationMessage::Chat(ChatMessage::assistant( + final_text.clone(), + ))); + self.trim_history(); + + if self.auto_save { + let summary = truncate_with_ellipsis(&final_text, 100); + let _ = self + .memory + .store("assistant_resp", &summary, MemoryCategory::Daily, None) + .await; + } + + return Ok(final_text); + } + + if !text.is_empty() { + self.history + .push(ConversationMessage::Chat(ChatMessage::assistant( + text.clone(), + ))); + print!("{text}"); + let _ = std::io::stdout().flush(); + } + + self.history.push(ConversationMessage::AssistantToolCalls { + text: response.text.clone(), + tool_calls: response.tool_calls.clone(), + }); + + let results = self.execute_tools(&calls).await; + let formatted = self.tool_dispatcher.format_results(&results); + self.history.push(formatted); + self.trim_history(); + } + + anyhow::bail!( + "Agent exceeded maximum tool iterations ({})", + self.config.max_tool_iterations + ) + } + + pub async fn run_single(&mut self, message: &str) -> Result { + self.turn(message).await + } + + pub async fn run_interactive(&mut self) -> Result<()> { + println!("🦀 ZeroClaw Interactive Mode"); + println!("Type /quit to exit.\n"); + + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let cli = crate::channels::CliChannel::new(); + + let listen_handle = tokio::spawn(async move { + let _ = crate::channels::Channel::listen(&cli, tx).await; + }); + + while let Some(msg) = rx.recv().await { + let response = match self.turn(&msg.content).await { + Ok(resp) => resp, + Err(e) => { + eprintln!("\nError: {e}\n"); + continue; + } + }; + println!("\n{response}\n"); + } + + listen_handle.abort(); + Ok(()) + } +} + +pub async fn run( + config: Config, + message: Option, + provider_override: Option, + model_override: Option, + temperature: f64, +) -> Result<()> { + let start = Instant::now(); + + let mut effective_config = config; + if let Some(p) = provider_override { + effective_config.default_provider = Some(p); + } + if let Some(m) = model_override { + effective_config.default_model = Some(m); + } + effective_config.default_temperature = temperature; + + let mut agent = Agent::from_config(&effective_config)?; + + let provider_name = effective_config + .default_provider + .as_deref() + .unwrap_or("openrouter") + .to_string(); + let model_name = effective_config + .default_model + .as_deref() + .unwrap_or("anthropic/claude-sonnet-4-20250514") + .to_string(); + + agent.observer.record_event(&ObserverEvent::AgentStart { + provider: provider_name, + model: model_name, + }); + + if let Some(msg) = message { + let response = agent.run_single(&msg).await?; + println!("{response}"); + } else { + agent.run_interactive().await?; + } + + agent.observer.record_event(&ObserverEvent::AgentEnd { + duration: start.elapsed(), + tokens_used: None, + cost_usd: None, + }); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use parking_lot::Mutex; + + struct MockProvider { + responses: Mutex>, + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("ok".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + let mut guard = self.responses.lock(); + if guard.is_empty() { + return Ok(crate::providers::ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } + } + + struct MockTool; + + #[async_trait] + impl Tool for MockTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "echo" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(crate::tools::ToolResult { + success: true, + output: "tool-out".into(), + error: None, + }) + } + } + + #[tokio::test] + async fn turn_without_tools_returns_text() { + let provider = Box::new(MockProvider { + responses: Mutex::new(vec![crate::providers::ChatResponse { + text: Some("hello".into()), + tool_calls: vec![], + }]), + }); + + let memory_cfg = crate::config::MemoryConfig { + backend: "none".into(), + ..crate::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + ); + + let observer: Arc = Arc::from(crate::observability::NoopObserver {}); + let mut agent = Agent::builder() + .provider(provider) + .tools(vec![Box::new(MockTool)]) + .memory(mem) + .observer(observer) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build() + .unwrap(); + + let response = agent.turn("hi").await.unwrap(); + assert_eq!(response, "hello"); + } + + #[tokio::test] + async fn turn_with_native_dispatcher_handles_tool_results_variant() { + let provider = Box::new(MockProvider { + responses: Mutex::new(vec![ + crate::providers::ChatResponse { + text: Some(String::new()), + tool_calls: vec![crate::providers::ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: "{}".into(), + }], + }, + crate::providers::ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }, + ]), + }); + + let memory_cfg = crate::config::MemoryConfig { + backend: "none".into(), + ..crate::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + ); + + let observer: Arc = Arc::from(crate::observability::NoopObserver {}); + let mut agent = Agent::builder() + .provider(provider) + .tools(vec![Box::new(MockTool)]) + .memory(mem) + .observer(observer) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build() + .unwrap(); + + let response = agent.turn("hi").await.unwrap(); + assert_eq!(response, "done"); + assert!(agent + .history() + .iter() + .any(|msg| matches!(msg, ConversationMessage::ToolResults(_)))); + } +} diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs new file mode 100644 index 0000000..673ec8c --- /dev/null +++ b/src/agent/dispatcher.rs @@ -0,0 +1,312 @@ +use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage}; +use crate::tools::{Tool, ToolSpec}; +use serde_json::Value; +use std::fmt::Write; + +#[derive(Debug, Clone)] +pub struct ParsedToolCall { + pub name: String, + pub arguments: Value, + pub tool_call_id: Option, +} + +#[derive(Debug, Clone)] +pub struct ToolExecutionResult { + pub name: String, + pub output: String, + pub success: bool, + pub tool_call_id: Option, +} + +pub trait ToolDispatcher: Send + Sync { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec); + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage; + fn prompt_instructions(&self, tools: &[Box]) -> String; + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec; + fn should_send_tool_specs(&self) -> bool; +} + +#[derive(Default)] +pub struct XmlToolDispatcher; + +impl XmlToolDispatcher { + fn parse_xml_tool_calls(response: &str) -> (String, Vec) { + let mut text_parts = Vec::new(); + let mut calls = Vec::new(); + let mut remaining = response; + + while let Some(start) = remaining.find("") { + let before = &remaining[..start]; + if !before.trim().is_empty() { + text_parts.push(before.trim().to_string()); + } + + if let Some(end) = remaining[start..].find("") { + let inner = &remaining[start + 11..start + end]; + match serde_json::from_str::(inner.trim()) { + Ok(parsed) => { + let name = parsed + .get("name") + .and_then(Value::as_str) + .unwrap_or("") + .to_string(); + if name.is_empty() { + remaining = &remaining[start + end + 12..]; + continue; + } + let arguments = parsed + .get("arguments") + .cloned() + .unwrap_or_else(|| Value::Object(serde_json::Map::new())); + calls.push(ParsedToolCall { + name, + arguments, + tool_call_id: None, + }); + } + Err(e) => { + tracing::warn!("Malformed JSON: {e}"); + } + } + remaining = &remaining[start + end + 12..]; + } else { + break; + } + } + + if !remaining.trim().is_empty() { + text_parts.push(remaining.trim().to_string()); + } + + (text_parts.join("\n"), calls) + } + + pub fn tool_specs(tools: &[Box]) -> Vec { + tools.iter().map(|tool| tool.spec()).collect() + } +} + +impl ToolDispatcher for XmlToolDispatcher { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { + let text = response.text_or_empty(); + Self::parse_xml_tool_calls(text) + } + + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { + let mut content = String::new(); + for result in results { + let status = if result.success { "ok" } else { "error" }; + let _ = writeln!( + content, + "\n{}\n", + result.name, status, result.output + ); + } + ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}"))) + } + + fn prompt_instructions(&self, tools: &[Box]) -> String { + let mut instructions = String::new(); + instructions.push_str("## Tool Use Protocol\n\n"); + instructions + .push_str("To use a tool, wrap a JSON object in tags:\n\n"); + instructions.push_str( + "```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n", + ); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools { + let _ = writeln!( + instructions, + "- **{}**: {}\n Parameters: `{}`", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + + instructions + } + + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { + history + .iter() + .flat_map(|msg| match msg { + ConversationMessage::Chat(chat) => vec![chat.clone()], + ConversationMessage::AssistantToolCalls { text, .. } => { + vec![ChatMessage::assistant(text.clone().unwrap_or_default())] + } + ConversationMessage::ToolResults(results) => { + let mut content = String::new(); + for result in results { + let _ = writeln!( + content, + "\n{}\n", + result.tool_call_id, result.content + ); + } + vec![ChatMessage::user(format!("[Tool results]\n{content}"))] + } + }) + .collect() + } + + fn should_send_tool_specs(&self) -> bool { + false + } +} + +pub struct NativeToolDispatcher; + +impl ToolDispatcher for NativeToolDispatcher { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { + let text = response.text.clone().unwrap_or_default(); + let calls = response + .tool_calls + .iter() + .map(|tc| ParsedToolCall { + name: tc.name.clone(), + arguments: serde_json::from_str(&tc.arguments) + .unwrap_or_else(|_| Value::Object(serde_json::Map::new())), + tool_call_id: Some(tc.id.clone()), + }) + .collect(); + (text, calls) + } + + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { + let messages = results + .iter() + .map(|result| ToolResultMessage { + tool_call_id: result + .tool_call_id + .clone() + .unwrap_or_else(|| "unknown".to_string()), + content: result.output.clone(), + }) + .collect(); + ConversationMessage::ToolResults(messages) + } + + fn prompt_instructions(&self, _tools: &[Box]) -> String { + String::new() + } + + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { + history + .iter() + .flat_map(|msg| match msg { + ConversationMessage::Chat(chat) => vec![chat.clone()], + ConversationMessage::AssistantToolCalls { text, tool_calls } => { + let payload = serde_json::json!({ + "content": text, + "tool_calls": tool_calls, + }); + vec![ChatMessage::assistant(payload.to_string())] + } + ConversationMessage::ToolResults(results) => results + .iter() + .map(|result| { + ChatMessage::tool( + serde_json::json!({ + "tool_call_id": result.tool_call_id, + "content": result.content, + }) + .to_string(), + ) + }) + .collect(), + }) + .collect() + } + + fn should_send_tool_specs(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn xml_dispatcher_parses_tool_calls() { + let response = ChatResponse { + text: Some( + "Checking\n{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}" + .into(), + ), + tool_calls: vec![], + }; + let dispatcher = XmlToolDispatcher; + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + } + + #[test] + fn native_dispatcher_roundtrip() { + let response = ChatResponse { + text: Some("ok".into()), + tool_calls: vec![crate::providers::ToolCall { + id: "tc1".into(), + name: "file_read".into(), + arguments: "{\"path\":\"a.txt\"}".into(), + }], + }; + let dispatcher = NativeToolDispatcher; + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1")); + + let msg = dispatcher.format_results(&[ToolExecutionResult { + name: "file_read".into(), + output: "hello".into(), + success: true, + tool_call_id: Some("tc1".into()), + }]); + match msg { + ConversationMessage::ToolResults(results) => { + assert_eq!(results.len(), 1); + assert_eq!(results[0].tool_call_id, "tc1"); + } + _ => panic!("expected tool results"), + } + } + + #[test] + fn xml_format_results_contains_tool_result_tags() { + let dispatcher = XmlToolDispatcher; + let msg = dispatcher.format_results(&[ToolExecutionResult { + name: "shell".into(), + output: "ok".into(), + success: true, + tool_call_id: None, + }]); + let rendered = match msg { + ConversationMessage::Chat(chat) => chat.content, + _ => String::new(), + }; + assert!(rendered.contains(" { + assert_eq!(results.len(), 1); + assert_eq!(results[0].tool_call_id, "tc-1"); + } + _ => panic!("expected ToolResults variant"), + } + } +} diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 0f611d7..b4d62a5 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,21 +1,208 @@ +use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatMessage, Provider, ToolCall}; use crate::runtime; use crate::security::SecurityPolicy; -use crate::tools; +use crate::tools::{self, Tool}; +use crate::util::truncate_with_ellipsis; use anyhow::Result; +use regex::{Regex, RegexSet}; use std::fmt::Write; -use std::sync::Arc; +use std::io::Write as _; +use std::sync::{Arc, LazyLock}; use std::time::Instant; +use uuid::Uuid; + +/// Maximum agentic tool-use iterations per user message to prevent runaway loops. +const MAX_TOOL_ITERATIONS: usize = 10; + +static SENSITIVE_KEY_PATTERNS: LazyLock = LazyLock::new(|| { + RegexSet::new([ + r"(?i)token", + r"(?i)api[_-]?key", + r"(?i)password", + r"(?i)secret", + r"(?i)user[_-]?key", + r"(?i)bearer", + r"(?i)credential", + ]) + .unwrap() +}); + +static SENSITIVE_KV_REGEX: LazyLock = LazyLock::new(|| { + Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap() +}); + +/// Scrub credentials from tool output to prevent accidental exfiltration. +/// Replaces known credential patterns with a redacted placeholder while preserving +/// a small prefix for context. +fn scrub_credentials(input: &str) -> String { + SENSITIVE_KV_REGEX + .replace_all(input, |caps: ®ex::Captures| { + let full_match = &caps[0]; + let key = &caps[1]; + let val = caps + .get(2) + .or(caps.get(3)) + .or(caps.get(4)) + .map(|m| m.as_str()) + .unwrap_or(""); + + // Preserve first 4 chars for context, then redact + let prefix = if val.len() > 4 { &val[..4] } else { "" }; + + if full_match.contains(':') { + if full_match.contains('"') { + format!("\"{}\": \"{}*[REDACTED]\"", key, prefix) + } else { + format!("{}: {}*[REDACTED]", key, prefix) + } + } else if full_match.contains('=') { + if full_match.contains('"') { + format!("{}=\"{}*[REDACTED]\"", key, prefix) + } else { + format!("{}={}*[REDACTED]", key, prefix) + } + } else { + format!("{}: {}*[REDACTED]", key, prefix) + } + }) + .to_string() +} + +/// Trigger auto-compaction when non-system message count exceeds this threshold. +const MAX_HISTORY_MESSAGES: usize = 50; + +/// Keep this many most-recent non-system messages after compaction. +const COMPACTION_KEEP_RECENT_MESSAGES: usize = 20; + +/// Safety cap for compaction source transcript passed to the summarizer. +const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000; + +/// Max characters retained in stored compaction summary. +const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000; + +/// Convert a tool registry to OpenAI function-calling format for native tool support. +fn tools_to_openai_format(tools_registry: &[Box]) -> Vec { + tools_registry + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name(), + "description": tool.description(), + "parameters": tool.parameters_schema() + } + }) + }) + .collect() +} + +fn autosave_memory_key(prefix: &str) -> String { + format!("{prefix}_{}", Uuid::new_v4()) +} + +/// Trim conversation history to prevent unbounded growth. +/// Preserves the system prompt (first message if role=system) and the most recent messages. +fn trim_history(history: &mut Vec) { + // Nothing to trim if within limit + let has_system = history.first().map_or(false, |m| m.role == "system"); + let non_system_count = if has_system { + history.len() - 1 + } else { + history.len() + }; + + if non_system_count <= MAX_HISTORY_MESSAGES { + return; + } + + let start = if has_system { 1 } else { 0 }; + let to_remove = non_system_count - MAX_HISTORY_MESSAGES; + history.drain(start..start + to_remove); +} + +fn build_compaction_transcript(messages: &[ChatMessage]) -> String { + let mut transcript = String::new(); + for msg in messages { + let role = msg.role.to_uppercase(); + let _ = writeln!(transcript, "{role}: {}", msg.content.trim()); + } + + if transcript.chars().count() > COMPACTION_MAX_SOURCE_CHARS { + truncate_with_ellipsis(&transcript, COMPACTION_MAX_SOURCE_CHARS) + } else { + transcript + } +} + +fn apply_compaction_summary( + history: &mut Vec, + start: usize, + compact_end: usize, + summary: &str, +) { + let summary_msg = ChatMessage::assistant(format!("[Compaction summary]\n{}", summary.trim())); + history.splice(start..compact_end, std::iter::once(summary_msg)); +} + +async fn auto_compact_history( + history: &mut Vec, + provider: &dyn Provider, + model: &str, +) -> Result { + let has_system = history.first().map_or(false, |m| m.role == "system"); + let non_system_count = if has_system { + history.len().saturating_sub(1) + } else { + history.len() + }; + + if non_system_count <= MAX_HISTORY_MESSAGES { + return Ok(false); + } + + let start = if has_system { 1 } else { 0 }; + let keep_recent = COMPACTION_KEEP_RECENT_MESSAGES.min(non_system_count); + let compact_count = non_system_count.saturating_sub(keep_recent); + if compact_count == 0 { + return Ok(false); + } + + let compact_end = start + compact_count; + let to_compact: Vec = history[start..compact_end].to_vec(); + let transcript = build_compaction_transcript(&to_compact); + + let summarizer_system = "You are a conversation compaction engine. Summarize older chat history into concise context for future turns. Preserve: user preferences, commitments, decisions, unresolved tasks, key facts. Omit: filler, repeated chit-chat, verbose tool logs. Output plain text bullet points only."; + + let summarizer_user = format!( + "Summarize the following conversation history for context preservation. Keep it short (max 12 bullet points).\n\n{}", + transcript + ); + + let summary_raw = provider + .chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2) + .await + .unwrap_or_else(|_| { + // Fallback to deterministic local truncation when summarization fails. + truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS) + }); + + let summary = truncate_with_ellipsis(&summary_raw, COMPACTION_MAX_SUMMARY_CHARS); + apply_compaction_summary(history, start, compact_end, &summary); + + Ok(true) +} /// Build context preamble by searching memory for relevant entries async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); // Pull relevant memories for this message - if let Ok(entries) = mem.recall(user_msg, 5).await { + if let Ok(entries) = mem.recall(user_msg, 5, None).await { if !entries.is_empty() { context.push_str("[Memory context]\n"); for entry in &entries { @@ -28,6 +215,576 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { context } +/// Build hardware datasheet context from RAG when peripherals are enabled. +/// Includes pin-alias lookup (e.g. "red_led" → 13) when query matches, plus retrieved chunks. +fn build_hardware_context( + rag: &crate::rag::HardwareRag, + user_msg: &str, + boards: &[String], + chunk_limit: usize, +) -> String { + if rag.is_empty() || boards.is_empty() { + return String::new(); + } + + let mut context = String::new(); + + // Pin aliases: when user says "red led", inject "red_led: 13" for matching boards + let pin_ctx = rag.pin_alias_context(user_msg, boards); + if !pin_ctx.is_empty() { + context.push_str(&pin_ctx); + } + + let chunks = rag.retrieve(user_msg, boards, chunk_limit); + if chunks.is_empty() && pin_ctx.is_empty() { + return String::new(); + } + + if !chunks.is_empty() { + context.push_str("[Hardware documentation]\n"); + } + for chunk in chunks { + let board_tag = chunk.board.as_deref().unwrap_or("generic"); + let _ = writeln!( + context, + "--- {} ({}) ---\n{}\n", + chunk.source, board_tag, chunk.content + ); + } + context.push('\n'); + context +} + +/// Find a tool by name in the registry. +fn find_tool<'a>(tools: &'a [Box], name: &str) -> Option<&'a dyn Tool> { + tools.iter().find(|t| t.name() == name).map(|t| t.as_ref()) +} + +fn parse_arguments_value(raw: Option<&serde_json::Value>) -> serde_json::Value { + match raw { + Some(serde_json::Value::String(s)) => serde_json::from_str::(s) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())), + Some(value) => value.clone(), + None => serde_json::Value::Object(serde_json::Map::new()), + } +} + +fn parse_tool_call_value(value: &serde_json::Value) -> Option { + if let Some(function) = value.get("function") { + let name = function + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .trim() + .to_string(); + if !name.is_empty() { + let arguments = parse_arguments_value(function.get("arguments")); + return Some(ParsedToolCall { name, arguments }); + } + } + + let name = value + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .trim() + .to_string(); + + if name.is_empty() { + return None; + } + + let arguments = parse_arguments_value(value.get("arguments")); + Some(ParsedToolCall { name, arguments }) +} + +fn parse_tool_calls_from_json_value(value: &serde_json::Value) -> Vec { + let mut calls = Vec::new(); + + if let Some(tool_calls) = value.get("tool_calls").and_then(|v| v.as_array()) { + for call in tool_calls { + if let Some(parsed) = parse_tool_call_value(call) { + calls.push(parsed); + } + } + + if !calls.is_empty() { + return calls; + } + } + + if let Some(array) = value.as_array() { + for item in array { + if let Some(parsed) = parse_tool_call_value(item) { + calls.push(parsed); + } + } + return calls; + } + + if let Some(parsed) = parse_tool_call_value(value) { + calls.push(parsed); + } + + calls +} + +const TOOL_CALL_OPEN_TAGS: [&str; 3] = ["", "", ""]; + +fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> { + tags.iter() + .filter_map(|tag| haystack.find(tag).map(|idx| (idx, *tag))) + .min_by_key(|(idx, _)| *idx) +} + +fn matching_tool_call_close_tag(open_tag: &str) -> Option<&'static str> { + match open_tag { + "" => Some(""), + "" => Some(""), + "" => Some(""), + _ => None, + } +} + +/// Extract JSON values from a string. +/// +/// # Security Warning +/// +/// This function extracts ANY JSON objects/arrays from the input. It MUST only +/// be used on content that is already trusted to be from the LLM, such as +/// content inside `` tags where the LLM has explicitly indicated intent +/// to make a tool call. Do NOT use this on raw user input or content that +/// could contain prompt injection payloads. +fn extract_json_values(input: &str) -> Vec { + let mut values = Vec::new(); + let trimmed = input.trim(); + if trimmed.is_empty() { + return values; + } + + if let Ok(value) = serde_json::from_str::(trimmed) { + values.push(value); + return values; + } + + let char_positions: Vec<(usize, char)> = trimmed.char_indices().collect(); + let mut idx = 0; + while idx < char_positions.len() { + let (byte_idx, ch) = char_positions[idx]; + if ch == '{' || ch == '[' { + let slice = &trimmed[byte_idx..]; + let mut stream = + serde_json::Deserializer::from_str(slice).into_iter::(); + if let Some(Ok(value)) = stream.next() { + let consumed = stream.byte_offset(); + if consumed > 0 { + values.push(value); + let next_byte = byte_idx + consumed; + while idx < char_positions.len() && char_positions[idx].0 < next_byte { + idx += 1; + } + continue; + } + } + } + idx += 1; + } + + values +} + +/// Parse tool calls from an LLM response that uses XML-style function calling. +/// +/// Expected format (common with system-prompt-guided tool use): +/// ```text +/// +/// {"name": "shell", "arguments": {"command": "ls"}} +/// +/// ``` +/// +/// Also accepts common tag variants (``, ``) for model +/// compatibility. +/// +/// Also supports JSON with `tool_calls` array from OpenAI-format responses. +fn parse_tool_calls(response: &str) -> (String, Vec) { + let mut text_parts = Vec::new(); + let mut calls = Vec::new(); + let mut remaining = response; + + // First, try to parse as OpenAI-style JSON response with tool_calls array + // This handles providers like Minimax that return tool_calls in native JSON format + if let Ok(json_value) = serde_json::from_str::(response.trim()) { + calls = parse_tool_calls_from_json_value(&json_value); + if !calls.is_empty() { + // If we found tool_calls, extract any content field as text + if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) { + if !content.trim().is_empty() { + text_parts.push(content.trim().to_string()); + } + } + return (text_parts.join("\n"), calls); + } + } + + // Fall back to XML-style tool-call tag parsing. + while let Some((start, open_tag)) = find_first_tag(remaining, &TOOL_CALL_OPEN_TAGS) { + // Everything before the tag is text + let before = &remaining[..start]; + if !before.trim().is_empty() { + text_parts.push(before.trim().to_string()); + } + + let Some(close_tag) = matching_tool_call_close_tag(open_tag) else { + break; + }; + + let after_open = &remaining[start + open_tag.len()..]; + if let Some(close_idx) = after_open.find(close_tag) { + let inner = &after_open[..close_idx]; + let mut parsed_any = false; + let json_values = extract_json_values(inner); + for value in json_values { + let parsed_calls = parse_tool_calls_from_json_value(&value); + if !parsed_calls.is_empty() { + parsed_any = true; + calls.extend(parsed_calls); + } + } + + if !parsed_any { + tracing::warn!("Malformed JSON: expected tool-call object in tag body"); + } + + remaining = &after_open[close_idx + close_tag.len()..]; + } else { + break; + } + } + + // SECURITY: We do NOT fall back to extracting arbitrary JSON from the response + // here. That would enable prompt injection attacks where malicious content + // (e.g., in emails, files, or web pages) could include JSON that mimics a + // tool call. Tool calls MUST be explicitly wrapped in either: + // 1. OpenAI-style JSON with a "tool_calls" array + // 2. ZeroClaw tool-call tags (, , ) + // This ensures only the LLM's intentional tool calls are executed. + + // Remaining text after last tool call + if !remaining.trim().is_empty() { + text_parts.push(remaining.trim().to_string()); + } + + (text_parts.join("\n"), calls) +} + +fn parse_structured_tool_calls(tool_calls: &[ToolCall]) -> Vec { + tool_calls + .iter() + .map(|call| ParsedToolCall { + name: call.name.clone(), + arguments: serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())), + }) + .collect() +} + +fn build_assistant_history_with_tool_calls(text: &str, tool_calls: &[ToolCall]) -> String { + let mut parts = Vec::new(); + + if !text.trim().is_empty() { + parts.push(text.trim().to_string()); + } + + for call in tool_calls { + let arguments = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone())); + let payload = serde_json::json!({ + "id": call.id, + "name": call.name, + "arguments": arguments, + }); + parts.push(format!("\n{payload}\n")); + } + + parts.join("\n") +} + +#[derive(Debug)] +struct ParsedToolCall { + name: String, + arguments: serde_json::Value, +} + +/// Execute a single turn of the agent loop: send messages, parse tool calls, +/// execute tools, and loop until the LLM produces a final text response. +/// When `silent` is true, suppresses stdout (for channel use). +#[allow(clippy::too_many_arguments)] +pub(crate) async fn agent_turn( + provider: &dyn Provider, + history: &mut Vec, + tools_registry: &[Box], + observer: &dyn Observer, + provider_name: &str, + model: &str, + temperature: f64, + silent: bool, +) -> Result { + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + provider_name, + model, + temperature, + silent, + None, + "channel", + ) + .await +} + +/// Execute a single turn of the agent loop: send messages, parse tool calls, +/// execute tools, and loop until the LLM produces a final text response. +#[allow(clippy::too_many_arguments)] +pub(crate) async fn run_tool_call_loop( + provider: &dyn Provider, + history: &mut Vec, + tools_registry: &[Box], + observer: &dyn Observer, + provider_name: &str, + model: &str, + temperature: f64, + silent: bool, + approval: Option<&ApprovalManager>, + channel_name: &str, +) -> Result { + // Build native tool definitions once if the provider supports them. + let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty(); + let tool_definitions = if use_native_tools { + tools_to_openai_format(tools_registry) + } else { + Vec::new() + }; + + for _iteration in 0..MAX_TOOL_ITERATIONS { + observer.record_event(&ObserverEvent::LlmRequest { + provider: provider_name.to_string(), + model: model.to_string(), + messages_count: history.len(), + }); + + let llm_started_at = Instant::now(); + + // Choose between native tool-call API and prompt-based tool use. + let (response_text, parsed_text, tool_calls, assistant_history_content) = + if use_native_tools { + match provider + .chat_with_tools(history, &tool_definitions, model, temperature) + .await + { + Ok(resp) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: true, + error_message: None, + }); + let response_text = resp.text_or_empty().to_string(); + let mut calls = parse_structured_tool_calls(&resp.tool_calls); + let mut parsed_text = String::new(); + + if calls.is_empty() { + let (fallback_text, fallback_calls) = parse_tool_calls(&response_text); + if !fallback_text.is_empty() { + parsed_text = fallback_text; + } + calls = fallback_calls; + } + + let assistant_history_content = if resp.tool_calls.is_empty() { + response_text.clone() + } else { + build_assistant_history_with_tool_calls( + &response_text, + &resp.tool_calls, + ) + }; + + (response_text, parsed_text, calls, assistant_history_content) + } + Err(e) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: false, + error_message: Some(crate::providers::sanitize_api_error( + &e.to_string(), + )), + }); + return Err(e); + } + } + } else { + match provider + .chat_with_history(history, model, temperature) + .await + { + Ok(resp) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: true, + error_message: None, + }); + let response_text = resp; + let assistant_history_content = response_text.clone(); + let (parsed_text, calls) = parse_tool_calls(&response_text); + (response_text, parsed_text, calls, assistant_history_content) + } + Err(e) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: false, + error_message: Some(crate::providers::sanitize_api_error( + &e.to_string(), + )), + }); + return Err(e); + } + } + }; + + let display_text = if parsed_text.is_empty() { + response_text.clone() + } else { + parsed_text + }; + + if tool_calls.is_empty() { + // No tool calls — this is the final response + history.push(ChatMessage::assistant(response_text.clone())); + return Ok(display_text); + } + + // Print any text the LLM produced alongside tool calls (unless silent) + if !silent && !display_text.is_empty() { + print!("{display_text}"); + let _ = std::io::stdout().flush(); + } + + // Execute each tool call and build results + let mut tool_results = String::new(); + for call in &tool_calls { + // ── Approval hook ──────────────────────────────── + if let Some(mgr) = approval { + if mgr.needs_approval(&call.name) { + let request = ApprovalRequest { + tool_name: call.name.clone(), + arguments: call.arguments.clone(), + }; + + // Only prompt interactively on CLI; auto-approve on other channels. + let decision = if channel_name == "cli" { + mgr.prompt_cli(&request) + } else { + ApprovalResponse::Yes + }; + + mgr.record_decision(&call.name, &call.arguments, decision, channel_name); + + if decision == ApprovalResponse::No { + let _ = writeln!( + tool_results, + "\nDenied by user.\n", + call.name + ); + continue; + } + } + } + + observer.record_event(&ObserverEvent::ToolCallStart { + tool: call.name.clone(), + }); + let start = Instant::now(); + let result = if let Some(tool) = find_tool(tools_registry, &call.name) { + match tool.execute(call.arguments.clone()).await { + Ok(r) => { + observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: r.success, + }); + if r.success { + scrub_credentials(&r.output) + } else { + format!("Error: {}", r.error.unwrap_or_else(|| r.output)) + } + } + Err(e) => { + observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: false, + }); + format!("Error executing {}: {e}", call.name) + } + } + } else { + format!("Unknown tool: {}", call.name) + }; + + let _ = writeln!( + tool_results, + "\n{}\n", + call.name, result + ); + } + + // Add assistant message with tool calls + tool results to history + history.push(ChatMessage::assistant(assistant_history_content)); + history.push(ChatMessage::user(format!("[Tool results]\n{tool_results}"))); + } + + anyhow::bail!("Agent exceeded maximum tool iterations ({MAX_TOOL_ITERATIONS})") +} + +/// Build the tool instruction block for the system prompt so the LLM knows +/// how to invoke tools. +pub(crate) fn build_tool_instructions(tools_registry: &[Box]) -> String { + let mut instructions = String::new(); + instructions.push_str("\n## Tool Use Protocol\n\n"); + instructions.push_str("To use a tool, wrap a JSON object in tags:\n\n"); + instructions.push_str("```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n"); + instructions.push_str( + "CRITICAL: Output actual tags—never describe steps or give examples.\n\n", + ); + instructions.push_str("Example: User says \"what's the date?\". You MUST respond with:\n\n{\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}}\n\n\n"); + instructions.push_str("You may use multiple tool calls in a single response. "); + instructions.push_str("After tool execution, results appear in tags. "); + instructions + .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools_registry { + let _ = writeln!( + instructions, + "**{}**: {}\nParameters: `{}`\n", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + + instructions +} + #[allow(clippy::too_many_lines)] pub async fn run( config: Config, @@ -35,11 +792,13 @@ pub async fn run( provider_override: Option, model_override: Option, temperature: f64, -) -> Result<()> { + peripheral_overrides: Vec, +) -> Result { // ── Wire up agnostic subsystems ────────────────────────────── - let observer: Arc = - Arc::from(observability::create_observer(&config.observability)); - let _runtime = runtime::create_runtime(&config.runtime)?; + let base_observer = observability::create_observer(&config.observability); + let observer: Arc = Arc::from(base_observer); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, &config.workspace_dir, @@ -53,13 +812,44 @@ pub async fn run( )?); tracing::info!(backend = mem.name(), "Memory initialized"); - // ── Tools (including memory tools) ──────────────────────────── - let composio_key = if config.composio.enabled { - config.composio.api_key.as_deref() + // ── Peripherals (merge peripheral tools into registry) ─ + if !peripheral_overrides.is_empty() { + tracing::info!( + peripherals = ?peripheral_overrides, + "Peripheral overrides from CLI (config boards take precedence)" + ); + } + + // ── Tools (including memory tools and peripherals) ──────────── + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) } else { - None + (None, None) }; - let _tools = tools::all_tools(&security, mem.clone(), composio_key, &config.browser); + let mut tools_registry = tools::all_tools_with_runtime( + Arc::new(config.clone()), + &security, + runtime, + mem.clone(), + composio_key, + composio_entity_id, + &config.browser, + &config.http_request, + &config.workspace_dir, + &config.agents, + config.api_key.as_deref(), + &config, + ); + + let peripheral_tools: Vec> = + crate::peripherals::create_peripheral_tools(&config.peripherals).await?; + if !peripheral_tools.is_empty() { + tracing::info!(count = peripheral_tools.len(), "Peripheral tools added"); + tools_registry.extend(peripheral_tools); + } // ── Resolve provider ───────────────────────────────────────── let provider_name = provider_override @@ -70,12 +860,15 @@ pub async fn run( let model_name = model_override .as_deref() .or(config.default_model.as_deref()) - .unwrap_or("anthropic/claude-sonnet-4-20250514"); + .unwrap_or("anthropic/claude-sonnet-4"); - let provider: Box = providers::create_resilient_provider( + let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, + &config.model_routes, + model_name, )?; observer.record_event(&ObserverEvent::AgentStart { @@ -83,6 +876,26 @@ pub async fn run( model: model_name.to_string(), }); + // ── Hardware RAG (datasheet retrieval when peripherals + datasheet_dir) ── + let hardware_rag: Option = config + .peripherals + .datasheet_dir + .as_ref() + .filter(|d| !d.trim().is_empty()) + .map(|dir| crate::rag::HardwareRag::load(&config.workspace_dir, dir.trim())) + .and_then(Result::ok) + .filter(|r: &crate::rag::HardwareRag| !r.is_empty()); + if let Some(ref rag) = hardware_rag { + tracing::info!(chunks = rag.len(), "Hardware RAG loaded"); + } + + let board_names: Vec = config + .peripherals + .boards + .iter() + .map(|b| b.board.clone()) + .collect(); + // ── Build system prompt from workspace MD files (OpenClaw framework) ── let skills = crate::skills::load_skills(&config.workspace_dir); let mut tool_descs: Vec<(&str, &str)> = vec![ @@ -111,107 +924,1003 @@ pub async fn run( "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.", ), ]; + tool_descs.push(( + "cron_add", + "Create a cron job. Supports schedule kinds: cron, at, every; and job types: shell or agent.", + )); + tool_descs.push(( + "cron_list", + "List all cron jobs with schedule, status, and metadata.", + )); + tool_descs.push(("cron_remove", "Remove a cron job by job_id.")); + tool_descs.push(( + "cron_update", + "Patch a cron job (schedule, enabled, command/prompt, model, delivery, session_target).", + )); + tool_descs.push(( + "cron_run", + "Force-run a cron job immediately and record a run history entry.", + )); + tool_descs.push(("cron_runs", "Show recent run history for a cron job.")); + tool_descs.push(( + "screenshot", + "Capture a screenshot of the current screen. Returns file path and base64-encoded PNG. Use when: visual verification, UI inspection, debugging displays.", + )); + tool_descs.push(( + "image_info", + "Read image file metadata (format, dimensions, size) and optionally base64-encode it. Use when: inspecting images, preparing visual data for analysis.", + )); if config.browser.enabled { tool_descs.push(( "browser_open", "Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)", )); } - let system_prompt = crate::channels::build_system_prompt( + if config.composio.enabled { + tool_descs.push(( + "composio", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.", + )); + } + tool_descs.push(( + "schedule", + "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", + )); + if !config.agents.is_empty() { + tool_descs.push(( + "delegate", + "Delegate a sub-task to a specialized agent. Use when: task needs different model/capability, or to parallelize work.", + )); + } + if config.peripherals.enabled && !config.peripherals.boards.is_empty() { + tool_descs.push(( + "gpio_read", + "Read GPIO pin value (0 or 1) on connected hardware (STM32, Arduino). Use when: checking sensor/button state, LED status.", + )); + tool_descs.push(( + "gpio_write", + "Set GPIO pin high (1) or low (0) on connected hardware. Use when: turning LED on/off, controlling actuators.", + )); + tool_descs.push(( + "arduino_upload", + "Upload agent-generated Arduino sketch. Use when: user asks for 'make a heart', 'blink pattern', or custom LED behavior on Arduino. You write the full .ino code; ZeroClaw compiles and uploads it. Pin 13 = built-in LED on Uno.", + )); + tool_descs.push(( + "hardware_memory_map", + "Return flash and RAM address ranges for connected hardware. Use when: user asks for 'upper and lower memory addresses', 'memory map', or 'readable addresses'.", + )); + tool_descs.push(( + "hardware_board_info", + "Return full board info (chip, architecture, memory map) for connected hardware. Use when: user asks for 'board info', 'what board do I have', 'connected hardware', 'chip info', or 'what hardware'.", + )); + tool_descs.push(( + "hardware_memory_read", + "Read actual memory/register values from Nucleo via USB. Use when: user asks to 'read register values', 'read memory', 'dump lower memory 0-126', 'give address and value'. Params: address (hex, default 0x20000000), length (bytes, default 128).", + )); + tool_descs.push(( + "hardware_capabilities", + "Query connected hardware for reported GPIO pins and LED pin. Use when: user asks what pins are available.", + )); + } + let bootstrap_max_chars = if config.agent.compact_context { + Some(6000) + } else { + None + }; + let mut system_prompt = crate::channels::build_system_prompt( &config.workspace_dir, model_name, &tool_descs, &skills, + Some(&config.identity), + bootstrap_max_chars, ); + // Append structured tool-use instructions with schemas + system_prompt.push_str(&build_tool_instructions(&tools_registry)); + + // ── Approval manager (supervised mode) ─────────────────────── + let approval_manager = ApprovalManager::from_config(&config.autonomy); + // ── Execute ────────────────────────────────────────────────── let start = Instant::now(); + let mut final_output = String::new(); + if let Some(msg) = message { // Auto-save user message to memory if config.memory.auto_save { + let user_key = autosave_memory_key("user_msg"); let _ = mem - .store("user_msg", &msg, MemoryCategory::Conversation) + .store(&user_key, &msg, MemoryCategory::Conversation, None) .await; } - // Inject memory context into user message - let context = build_context(mem.as_ref(), &msg).await; + // Inject memory + hardware RAG context into user message + let mem_context = build_context(mem.as_ref(), &msg).await; + let rag_limit = if config.agent.compact_context { 2 } else { 5 }; + let hw_context = hardware_rag + .as_ref() + .map(|r| build_hardware_context(r, &msg, &board_names, rag_limit)) + .unwrap_or_default(); + let context = format!("{mem_context}{hw_context}"); let enriched = if context.is_empty() { msg.clone() } else { format!("{context}{msg}") }; - let response = provider - .chat_with_system(Some(&system_prompt), &enriched, model_name, temperature) - .await?; + let mut history = vec![ + ChatMessage::system(&system_prompt), + ChatMessage::user(&enriched), + ]; + + let response = run_tool_call_loop( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + provider_name, + model_name, + temperature, + false, + Some(&approval_manager), + "cli", + ) + .await?; + final_output = response.clone(); println!("{response}"); + observer.record_event(&ObserverEvent::TurnComplete); // Auto-save assistant response to daily log if config.memory.auto_save { - let summary = if response.len() > 100 { - format!("{}...", &response[..100]) - } else { - response.clone() - }; + let summary = truncate_with_ellipsis(&response, 100); + let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } else { println!("🦀 ZeroClaw Interactive Mode"); println!("Type /quit to exit.\n"); - - let (tx, mut rx) = tokio::sync::mpsc::channel(32); let cli = crate::channels::CliChannel::new(); - // Spawn listener - let listen_handle = tokio::spawn(async move { - let _ = crate::channels::Channel::listen(&cli, tx).await; - }); + // Persistent conversation history across turns + let mut history = vec![ChatMessage::system(&system_prompt)]; + + loop { + print!("> "); + let _ = std::io::stdout().flush(); + + let mut input = String::new(); + match std::io::stdin().read_line(&mut input) { + Ok(0) => break, + Ok(_) => {} + Err(e) => { + eprintln!("\nError reading input: {e}\n"); + break; + } + } + + let user_input = input.trim().to_string(); + if user_input.is_empty() { + continue; + } + if user_input == "/quit" || user_input == "/exit" { + break; + } - while let Some(msg) = rx.recv().await { // Auto-save conversation turns if config.memory.auto_save { + let user_key = autosave_memory_key("user_msg"); let _ = mem - .store("user_msg", &msg.content, MemoryCategory::Conversation) + .store(&user_key, &user_input, MemoryCategory::Conversation, None) .await; } - // Inject memory context into user message - let context = build_context(mem.as_ref(), &msg.content).await; + // Inject memory + hardware RAG context into user message + let mem_context = build_context(mem.as_ref(), &user_input).await; + let rag_limit = if config.agent.compact_context { 2 } else { 5 }; + let hw_context = hardware_rag + .as_ref() + .map(|r| build_hardware_context(r, &user_input, &board_names, rag_limit)) + .unwrap_or_default(); + let context = format!("{mem_context}{hw_context}"); let enriched = if context.is_empty() { - msg.content.clone() + user_input.clone() } else { - format!("{context}{}", msg.content) + format!("{context}{user_input}") }; - let response = provider - .chat_with_system(Some(&system_prompt), &enriched, model_name, temperature) - .await?; - println!("\n{response}\n"); + history.push(ChatMessage::user(&enriched)); + + let response = match run_tool_call_loop( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + provider_name, + model_name, + temperature, + false, + Some(&approval_manager), + "cli", + ) + .await + { + Ok(resp) => resp, + Err(e) => { + eprintln!("\nError: {e}\n"); + continue; + } + }; + final_output = response.clone(); + if let Err(e) = crate::channels::Channel::send( + &cli, + &crate::channels::traits::SendMessage::new(format!("\n{response}\n"), "user"), + ) + .await + { + eprintln!("\nError sending CLI response: {e}\n"); + } + observer.record_event(&ObserverEvent::TurnComplete); + + // Auto-compaction before hard trimming to preserve long-context signal. + if let Ok(compacted) = + auto_compact_history(&mut history, provider.as_ref(), model_name).await + { + if compacted { + println!("🧹 Auto-compaction complete"); + } + } + + // Hard cap as a safety net. + trim_history(&mut history); if config.memory.auto_save { - let summary = if response.len() > 100 { - format!("{}...", &response[..100]) - } else { - response.clone() - }; + let summary = truncate_with_ellipsis(&response, 100); + let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } - - listen_handle.abort(); } let duration = start.elapsed(); observer.record_event(&ObserverEvent::AgentEnd { duration, tokens_used: None, + cost_usd: None, }); - Ok(()) + Ok(final_output) +} + +/// Process a single message through the full agent (with tools, peripherals, memory). +/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use. +pub async fn process_message(config: Config, message: &str) -> Result { + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + let mem: Arc = Arc::from(memory::create_memory( + &config.memory, + &config.workspace_dir, + config.api_key.as_deref(), + )?); + + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) + } else { + (None, None) + }; + let mut tools_registry = tools::all_tools_with_runtime( + Arc::new(config.clone()), + &security, + runtime, + mem.clone(), + composio_key, + composio_entity_id, + &config.browser, + &config.http_request, + &config.workspace_dir, + &config.agents, + config.api_key.as_deref(), + &config, + ); + let peripheral_tools: Vec> = + crate::peripherals::create_peripheral_tools(&config.peripherals).await?; + tools_registry.extend(peripheral_tools); + + let provider_name = config.default_provider.as_deref().unwrap_or("openrouter"); + let model_name = config + .default_model + .clone() + .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); + let provider: Box = providers::create_routed_provider( + provider_name, + config.api_key.as_deref(), + config.api_url.as_deref(), + &config.reliability, + &config.model_routes, + &model_name, + )?; + + let hardware_rag: Option = config + .peripherals + .datasheet_dir + .as_ref() + .filter(|d| !d.trim().is_empty()) + .map(|dir| crate::rag::HardwareRag::load(&config.workspace_dir, dir.trim())) + .and_then(Result::ok) + .filter(|r: &crate::rag::HardwareRag| !r.is_empty()); + let board_names: Vec = config + .peripherals + .boards + .iter() + .map(|b| b.board.clone()) + .collect(); + + let skills = crate::skills::load_skills(&config.workspace_dir); + let mut tool_descs: Vec<(&str, &str)> = vec![ + ("shell", "Execute terminal commands."), + ("file_read", "Read file contents."), + ("file_write", "Write file contents."), + ("memory_store", "Save to memory."), + ("memory_recall", "Search memory."), + ("memory_forget", "Delete a memory entry."), + ("screenshot", "Capture a screenshot."), + ("image_info", "Read image metadata."), + ]; + if config.browser.enabled { + tool_descs.push(("browser_open", "Open approved URLs in browser.")); + } + if config.composio.enabled { + tool_descs.push(("composio", "Execute actions on 1000+ apps via Composio.")); + } + if config.peripherals.enabled && !config.peripherals.boards.is_empty() { + tool_descs.push(("gpio_read", "Read GPIO pin value on connected hardware.")); + tool_descs.push(( + "gpio_write", + "Set GPIO pin high or low on connected hardware.", + )); + tool_descs.push(( + "arduino_upload", + "Upload Arduino sketch. Use for 'make a heart', custom patterns. You write full .ino code; ZeroClaw uploads it.", + )); + tool_descs.push(( + "hardware_memory_map", + "Return flash and RAM address ranges. Use when user asks for memory addresses or memory map.", + )); + tool_descs.push(( + "hardware_board_info", + "Return full board info (chip, architecture, memory map). Use when user asks for board info, what board, connected hardware, or chip info.", + )); + tool_descs.push(( + "hardware_memory_read", + "Read actual memory/register values from Nucleo. Use when user asks to read registers, read memory, dump lower memory 0-126, or give address and value.", + )); + tool_descs.push(( + "hardware_capabilities", + "Query connected hardware for reported GPIO pins and LED pin. Use when user asks what pins are available.", + )); + } + let bootstrap_max_chars = if config.agent.compact_context { + Some(6000) + } else { + None + }; + let mut system_prompt = crate::channels::build_system_prompt( + &config.workspace_dir, + &model_name, + &tool_descs, + &skills, + Some(&config.identity), + bootstrap_max_chars, + ); + system_prompt.push_str(&build_tool_instructions(&tools_registry)); + + let mem_context = build_context(mem.as_ref(), message).await; + let rag_limit = if config.agent.compact_context { 2 } else { 5 }; + let hw_context = hardware_rag + .as_ref() + .map(|r| build_hardware_context(r, message, &board_names, rag_limit)) + .unwrap_or_default(); + let context = format!("{mem_context}{hw_context}"); + let enriched = if context.is_empty() { + message.to_string() + } else { + format!("{context}{message}") + }; + + let mut history = vec![ + ChatMessage::system(&system_prompt), + ChatMessage::user(&enriched), + ]; + + agent_turn( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + provider_name, + &model_name, + config.default_temperature, + true, + ) + .await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scrub_credentials() { + let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\""; + let scrubbed = scrub_credentials(input); + assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]")); + assert!(scrubbed.contains("token: 1234*[REDACTED]")); + assert!(scrubbed.contains("password=\"secr*[REDACTED]\"")); + assert!(!scrubbed.contains("abcdef")); + assert!(!scrubbed.contains("secret123456")); + } + + #[test] + fn test_scrub_credentials_json() { + let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#; + let scrubbed = scrub_credentials(input); + assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\"")); + assert!(scrubbed.contains("public")); + } + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use tempfile::TempDir; + + #[test] + fn parse_tool_calls_extracts_single_call() { + let response = r#"Let me check that. + +{"name": "shell", "arguments": {"command": "ls -la"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert_eq!(text, "Let me check that."); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "ls -la" + ); + } + + #[test] + fn parse_tool_calls_extracts_multiple_calls() { + let response = r#" +{"name": "file_read", "arguments": {"path": "a.txt"}} + + +{"name": "file_read", "arguments": {"path": "b.txt"}} +"#; + + let (_, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "file_read"); + assert_eq!(calls[1].name, "file_read"); + } + + #[test] + fn parse_tool_calls_returns_text_only_when_no_calls() { + let response = "Just a normal response with no tools."; + let (text, calls) = parse_tool_calls(response); + assert_eq!(text, "Just a normal response with no tools."); + assert!(calls.is_empty()); + } + + #[test] + fn parse_tool_calls_handles_malformed_json() { + let response = r#" +not valid json + +Some text after."#; + + let (text, calls) = parse_tool_calls(response); + assert!(calls.is_empty()); + assert!(text.contains("Some text after.")); + } + + #[test] + fn parse_tool_calls_text_before_and_after() { + let response = r#"Before text. + +{"name": "shell", "arguments": {"command": "echo hi"}} + +After text."#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("Before text.")); + assert!(text.contains("After text.")); + assert_eq!(calls.len(), 1); + } + + #[test] + fn parse_tool_calls_handles_openai_format() { + // OpenAI-style response with tool_calls array + let response = r#"{"content": "Let me check that for you.", "tool_calls": [{"type": "function", "function": {"name": "shell", "arguments": "{\"command\": \"ls -la\"}"}}]}"#; + + let (text, calls) = parse_tool_calls(response); + assert_eq!(text, "Let me check that for you."); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "ls -la" + ); + } + + #[test] + fn parse_tool_calls_handles_openai_format_multiple_calls() { + let response = r#"{"tool_calls": [{"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"a.txt\"}"}}, {"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"b.txt\"}"}}]}"#; + + let (_, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "file_read"); + assert_eq!(calls[1].name, "file_read"); + } + + #[test] + fn parse_tool_calls_openai_format_without_content() { + // Some providers don't include content field with tool_calls + let response = r#"{"tool_calls": [{"type": "function", "function": {"name": "memory_recall", "arguments": "{}"}}]}"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); // No content field + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "memory_recall"); + } + + #[test] + fn parse_tool_calls_handles_markdown_json_inside_tool_call_tag() { + let response = r#" +```json +{"name": "file_write", "arguments": {"path": "test.py", "content": "print('ok')"}} +``` +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "file_write"); + assert_eq!( + calls[0].arguments.get("path").unwrap().as_str().unwrap(), + "test.py" + ); + } + + #[test] + fn parse_tool_calls_handles_noisy_tool_call_tag_body() { + let response = r#" +I will now call the tool with this payload: +{"name": "shell", "arguments": {"command": "pwd"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "pwd" + ); + } + + #[test] + fn parse_tool_calls_handles_toolcall_tag_alias() { + let response = r#" +{"name": "shell", "arguments": {"command": "date"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "date" + ); + } + + #[test] + fn parse_tool_calls_handles_tool_dash_call_tag_alias() { + let response = r#" +{"name": "shell", "arguments": {"command": "whoami"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "whoami" + ); + } + + #[test] + fn parse_tool_calls_does_not_cross_match_alias_tags() { + let response = r#" +{"name": "shell", "arguments": {"command": "date"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(calls.is_empty()); + assert!(text.contains("")); + assert!(text.contains("")); + } + + #[test] + fn parse_tool_calls_rejects_raw_tool_json_without_tags() { + // SECURITY: Raw JSON without explicit wrappers should NOT be parsed + // This prevents prompt injection attacks where malicious content + // could include JSON that mimics a tool call. + let response = r#"Sure, creating the file now. +{"name": "file_write", "arguments": {"path": "hello.py", "content": "print('hello')"}}"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("Sure, creating the file now.")); + assert_eq!( + calls.len(), + 0, + "Raw JSON without wrappers should not be parsed" + ); + } + + #[test] + fn build_tool_instructions_includes_all_tools() { + use crate::security::SecurityPolicy; + let security = Arc::new(SecurityPolicy::from_config( + &crate::config::AutonomyConfig::default(), + std::path::Path::new("/tmp"), + )); + let tools = tools::default_tools(security); + let instructions = build_tool_instructions(&tools); + + assert!(instructions.contains("## Tool Use Protocol")); + assert!(instructions.contains("")); + assert!(instructions.contains("shell")); + assert!(instructions.contains("file_read")); + assert!(instructions.contains("file_write")); + } + + #[test] + fn tools_to_openai_format_produces_valid_schema() { + use crate::security::SecurityPolicy; + let security = Arc::new(SecurityPolicy::from_config( + &crate::config::AutonomyConfig::default(), + std::path::Path::new("/tmp"), + )); + let tools = tools::default_tools(security); + let formatted = tools_to_openai_format(&tools); + + assert!(!formatted.is_empty()); + for tool_json in &formatted { + assert_eq!(tool_json["type"], "function"); + assert!(tool_json["function"]["name"].is_string()); + assert!(tool_json["function"]["description"].is_string()); + assert!(!tool_json["function"]["name"].as_str().unwrap().is_empty()); + } + // Verify known tools are present + let names: Vec<&str> = formatted + .iter() + .filter_map(|t| t["function"]["name"].as_str()) + .collect(); + assert!(names.contains(&"shell")); + assert!(names.contains(&"file_read")); + } + + #[test] + fn trim_history_preserves_system_prompt() { + let mut history = vec![ChatMessage::system("system prompt")]; + for i in 0..MAX_HISTORY_MESSAGES + 20 { + history.push(ChatMessage::user(format!("msg {i}"))); + } + let original_len = history.len(); + assert!(original_len > MAX_HISTORY_MESSAGES + 1); + + trim_history(&mut history); + + // System prompt preserved + assert_eq!(history[0].role, "system"); + assert_eq!(history[0].content, "system prompt"); + // Trimmed to limit + assert_eq!(history.len(), MAX_HISTORY_MESSAGES + 1); // +1 for system + // Most recent messages preserved + let last = &history[history.len() - 1]; + assert_eq!(last.content, format!("msg {}", MAX_HISTORY_MESSAGES + 19)); + } + + #[test] + fn trim_history_noop_when_within_limit() { + let mut history = vec![ + ChatMessage::system("sys"), + ChatMessage::user("hello"), + ChatMessage::assistant("hi"), + ]; + trim_history(&mut history); + assert_eq!(history.len(), 3); + } + + #[test] + fn build_compaction_transcript_formats_roles() { + let messages = vec![ + ChatMessage::user("I like dark mode"), + ChatMessage::assistant("Got it"), + ]; + let transcript = build_compaction_transcript(&messages); + assert!(transcript.contains("USER: I like dark mode")); + assert!(transcript.contains("ASSISTANT: Got it")); + } + + #[test] + fn apply_compaction_summary_replaces_old_segment() { + let mut history = vec![ + ChatMessage::system("sys"), + ChatMessage::user("old 1"), + ChatMessage::assistant("old 2"), + ChatMessage::user("recent 1"), + ChatMessage::assistant("recent 2"), + ]; + + apply_compaction_summary(&mut history, 1, 3, "- user prefers concise replies"); + + assert_eq!(history.len(), 4); + assert!(history[1].content.contains("Compaction summary")); + assert!(history[2].content.contains("recent 1")); + assert!(history[3].content.contains("recent 2")); + } + + #[test] + fn autosave_memory_key_has_prefix_and_uniqueness() { + let key1 = autosave_memory_key("user_msg"); + let key2 = autosave_memory_key("user_msg"); + + assert!(key1.starts_with("user_msg_")); + assert!(key2.starts_with("user_msg_")); + assert_ne!(key1, key2); + } + + #[tokio::test] + async fn autosave_memory_keys_preserve_multiple_turns() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let key1 = autosave_memory_key("user_msg"); + let key2 = autosave_memory_key("user_msg"); + + mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None) + .await + .unwrap(); + mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None) + .await + .unwrap(); + + assert_eq!(mem.count().await.unwrap(), 2); + + let recalled = mem.recall("45", 5, None).await.unwrap(); + assert!(recalled.iter().any(|entry| entry.content.contains("45"))); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Tool Call Parsing Edge Cases + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_tool_calls_handles_empty_tool_result() { + // Recovery: Empty tool_result tag should be handled gracefully + let response = r#"I'll run that command. + + + +Done."#; + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("Done.")); + assert!(calls.is_empty()); + } + + #[test] + fn parse_arguments_value_handles_null() { + // Recovery: null arguments are returned as-is (Value::Null) + let value = serde_json::json!(null); + let result = parse_arguments_value(Some(&value)); + assert!(result.is_null()); + } + + #[test] + fn parse_tool_calls_handles_empty_tool_calls_array() { + // Recovery: Empty tool_calls array returns original response (no tool parsing) + let response = r#"{"content": "Hello", "tool_calls": []}"#; + let (text, calls) = parse_tool_calls(response); + // When tool_calls is empty, the entire JSON is returned as text + assert!(text.contains("Hello")); + assert!(calls.is_empty()); + } + + #[test] + fn parse_tool_calls_handles_whitespace_only_name() { + // Recovery: Whitespace-only tool name should return None + let value = serde_json::json!({"function": {"name": " ", "arguments": {}}}); + let result = parse_tool_call_value(&value); + assert!(result.is_none()); + } + + #[test] + fn parse_tool_calls_handles_empty_string_arguments() { + // Recovery: Empty string arguments should be handled + let value = serde_json::json!({"name": "test", "arguments": ""}); + let result = parse_tool_call_value(&value); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "test"); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - History Management + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn trim_history_with_no_system_prompt() { + // Recovery: History without system prompt should trim correctly + let mut history = vec![]; + for i in 0..MAX_HISTORY_MESSAGES + 20 { + history.push(ChatMessage::user(format!("msg {i}"))); + } + trim_history(&mut history); + assert_eq!(history.len(), MAX_HISTORY_MESSAGES); + } + + #[test] + fn trim_history_preserves_role_ordering() { + // Recovery: After trimming, role ordering should remain consistent + let mut history = vec![ChatMessage::system("system")]; + for i in 0..MAX_HISTORY_MESSAGES + 10 { + history.push(ChatMessage::user(format!("user {i}"))); + history.push(ChatMessage::assistant(format!("assistant {i}"))); + } + trim_history(&mut history); + assert_eq!(history[0].role, "system"); + assert_eq!(history[history.len() - 1].role, "assistant"); + } + + #[test] + fn trim_history_with_only_system_prompt() { + // Recovery: Only system prompt should not be trimmed + let mut history = vec![ChatMessage::system("system prompt")]; + trim_history(&mut history); + assert_eq!(history.len(), 1); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Arguments Parsing + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_arguments_value_handles_invalid_json_string() { + // Recovery: Invalid JSON string should return empty object + let value = serde_json::Value::String("not valid json".to_string()); + let result = parse_arguments_value(Some(&value)); + assert!(result.is_object()); + assert!(result.as_object().unwrap().is_empty()); + } + + #[test] + fn parse_arguments_value_handles_none() { + // Recovery: None arguments should return empty object + let result = parse_arguments_value(None); + assert!(result.is_object()); + assert!(result.as_object().unwrap().is_empty()); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - JSON Extraction + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn extract_json_values_handles_empty_string() { + // Recovery: Empty input should return empty vec + let result = extract_json_values(""); + assert!(result.is_empty()); + } + + #[test] + fn extract_json_values_handles_whitespace_only() { + // Recovery: Whitespace only should return empty vec + let result = extract_json_values(" \n\t "); + assert!(result.is_empty()); + } + + #[test] + fn extract_json_values_handles_multiple_objects() { + // Recovery: Multiple JSON objects should all be extracted + let input = r#"{"a": 1}{"b": 2}{"c": 3}"#; + let result = extract_json_values(input); + assert_eq!(result.len(), 3); + } + + #[test] + fn extract_json_values_handles_arrays() { + // Recovery: JSON arrays should be extracted + let input = r#"[1, 2, 3]{"key": "value"}"#; + let result = extract_json_values(input); + assert_eq!(result.len(), 2); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Constants Validation + // ═══════════════════════════════════════════════════════════════════════ + + const _: () = { + assert!(MAX_TOOL_ITERATIONS > 0); + assert!(MAX_TOOL_ITERATIONS <= 100); + assert!(MAX_HISTORY_MESSAGES > 0); + assert!(MAX_HISTORY_MESSAGES <= 1000); + }; + + #[test] + fn constants_bounds_are_compile_time_checked() { + // Bounds are enforced by the const assertions above. + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Tool Call Value Parsing + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_tool_call_value_handles_missing_name_field() { + // Recovery: Missing name field should return None + let value = serde_json::json!({"function": {"arguments": {}}}); + let result = parse_tool_call_value(&value); + assert!(result.is_none()); + } + + #[test] + fn parse_tool_call_value_handles_top_level_name() { + // Recovery: Tool call with name at top level (non-OpenAI format) + let value = serde_json::json!({"name": "test_tool", "arguments": {}}); + let result = parse_tool_call_value(&value); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "test_tool"); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_empty_array() { + // Recovery: Empty tool_calls array should return empty vec + let value = serde_json::json!({"tool_calls": []}); + let result = parse_tool_calls_from_json_value(&value); + assert!(result.is_empty()); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_missing_tool_calls() { + // Recovery: Missing tool_calls field should fall through + let value = serde_json::json!({"name": "test", "arguments": {}}); + let result = parse_tool_calls_from_json_value(&value); + assert_eq!(result.len(), 1); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_top_level_array() { + // Recovery: Top-level array of tool calls + let value = serde_json::json!([ + {"name": "tool_a", "arguments": {}}, + {"name": "tool_b", "arguments": {}} + ]); + let result = parse_tool_calls_from_json_value(&value); + assert_eq!(result.len(), 2); + } } diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs new file mode 100644 index 0000000..0cc530f --- /dev/null +++ b/src/agent/memory_loader.rs @@ -0,0 +1,125 @@ +use crate::memory::Memory; +use async_trait::async_trait; +use std::fmt::Write; + +#[async_trait] +pub trait MemoryLoader: Send + Sync { + async fn load_context(&self, memory: &dyn Memory, user_message: &str) + -> anyhow::Result; +} + +pub struct DefaultMemoryLoader { + limit: usize, +} + +impl Default for DefaultMemoryLoader { + fn default() -> Self { + Self { limit: 5 } + } +} + +impl DefaultMemoryLoader { + pub fn new(limit: usize) -> Self { + Self { + limit: limit.max(1), + } + } +} + +#[async_trait] +impl MemoryLoader for DefaultMemoryLoader { + async fn load_context( + &self, + memory: &dyn Memory, + user_message: &str, + ) -> anyhow::Result { + let entries = memory.recall(user_message, self.limit, None).await?; + if entries.is_empty() { + return Ok(String::new()); + } + + let mut context = String::from("[Memory context]\n"); + for entry in entries { + let _ = writeln!(context, "- {}: {}", entry.key, entry.content); + } + context.push('\n'); + Ok(context) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::{Memory, MemoryCategory, MemoryEntry}; + + struct MockMemory; + + #[async_trait] + impl Memory for MockMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + if limit == 0 { + return Ok(vec![]); + } + Ok(vec![MemoryEntry { + id: "1".into(), + key: "k".into(), + content: "v".into(), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: None, + }]) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![]) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(true) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "mock" + } + } + + #[tokio::test] + async fn default_loader_formats_context() { + let loader = DefaultMemoryLoader::default(); + let context = loader.load_context(&MockMemory, "hello").await.unwrap(); + assert!(context.contains("[Memory context]")); + assert!(context.contains("- k: v")); + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index f889613..29c96a5 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,3 +1,14 @@ +#[allow(clippy::module_inception)] +pub mod agent; +pub mod dispatcher; pub mod loop_; +pub mod memory_loader; +pub mod prompt; -pub use loop_::run; +#[cfg(test)] +mod tests; + +#[allow(unused_imports)] +pub use agent::{Agent, AgentBuilder}; +#[allow(unused_imports)] +pub use loop_::{process_message, run}; diff --git a/src/agent/prompt.rs b/src/agent/prompt.rs new file mode 100644 index 0000000..bdc426f --- /dev/null +++ b/src/agent/prompt.rs @@ -0,0 +1,304 @@ +use crate::config::IdentityConfig; +use crate::identity; +use crate::skills::Skill; +use crate::tools::Tool; +use anyhow::Result; +use chrono::Local; +use std::fmt::Write; +use std::path::Path; + +const BOOTSTRAP_MAX_CHARS: usize = 20_000; + +pub struct PromptContext<'a> { + pub workspace_dir: &'a Path, + pub model_name: &'a str, + pub tools: &'a [Box], + pub skills: &'a [Skill], + pub identity_config: Option<&'a IdentityConfig>, + pub dispatcher_instructions: &'a str, +} + +pub trait PromptSection: Send + Sync { + fn name(&self) -> &str; + fn build(&self, ctx: &PromptContext<'_>) -> Result; +} + +#[derive(Default)] +pub struct SystemPromptBuilder { + sections: Vec>, +} + +impl SystemPromptBuilder { + pub fn with_defaults() -> Self { + Self { + sections: vec![ + Box::new(IdentitySection), + Box::new(ToolsSection), + Box::new(SafetySection), + Box::new(SkillsSection), + Box::new(WorkspaceSection), + Box::new(DateTimeSection), + Box::new(RuntimeSection), + ], + } + } + + pub fn add_section(mut self, section: Box) -> Self { + self.sections.push(section); + self + } + + pub fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut output = String::new(); + for section in &self.sections { + let part = section.build(ctx)?; + if part.trim().is_empty() { + continue; + } + output.push_str(part.trim_end()); + output.push_str("\n\n"); + } + Ok(output) + } +} + +pub struct IdentitySection; +pub struct ToolsSection; +pub struct SafetySection; +pub struct SkillsSection; +pub struct WorkspaceSection; +pub struct RuntimeSection; +pub struct DateTimeSection; + +impl PromptSection for IdentitySection { + fn name(&self) -> &str { + "identity" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut prompt = String::from("## Project Context\n\n"); + if let Some(config) = ctx.identity_config { + if identity::is_aieos_configured(config) { + if let Ok(Some(aieos)) = identity::load_aieos_identity(config, ctx.workspace_dir) { + let rendered = identity::aieos_to_system_prompt(&aieos); + if !rendered.is_empty() { + prompt.push_str(&rendered); + return Ok(prompt); + } + } + } + } + + prompt.push_str( + "The following workspace files define your identity, behavior, and context.\n\n", + ); + for file in [ + "AGENTS.md", + "SOUL.md", + "TOOLS.md", + "IDENTITY.md", + "USER.md", + "HEARTBEAT.md", + "BOOTSTRAP.md", + "MEMORY.md", + ] { + inject_workspace_file(&mut prompt, ctx.workspace_dir, file); + } + + Ok(prompt) + } +} + +impl PromptSection for ToolsSection { + fn name(&self) -> &str { + "tools" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut out = String::from("## Tools\n\n"); + for tool in ctx.tools { + let _ = writeln!( + out, + "- **{}**: {}\n Parameters: `{}`", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + if !ctx.dispatcher_instructions.is_empty() { + out.push('\n'); + out.push_str(ctx.dispatcher_instructions); + } + Ok(out) + } +} + +impl PromptSection for SafetySection { + fn name(&self) -> &str { + "safety" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> Result { + Ok("## Safety\n\n- Do not exfiltrate private data.\n- Do not run destructive commands without asking.\n- Do not bypass oversight or approval mechanisms.\n- Prefer `trash` over `rm`.\n- When in doubt, ask before acting externally.".into()) + } +} + +impl PromptSection for SkillsSection { + fn name(&self) -> &str { + "skills" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + if ctx.skills.is_empty() { + return Ok(String::new()); + } + + let mut prompt = String::from("## Available Skills\n\n\n"); + for skill in ctx.skills { + let location = skill.location.clone().unwrap_or_else(|| { + ctx.workspace_dir + .join("skills") + .join(&skill.name) + .join("SKILL.md") + }); + let _ = writeln!( + prompt, + " \n {}\n {}\n {}\n ", + skill.name, + skill.description, + location.display() + ); + } + prompt.push_str(""); + Ok(prompt) + } +} + +impl PromptSection for WorkspaceSection { + fn name(&self) -> &str { + "workspace" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + Ok(format!( + "## Workspace\n\nWorking directory: `{}`", + ctx.workspace_dir.display() + )) + } +} + +impl PromptSection for RuntimeSection { + fn name(&self) -> &str { + "runtime" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let host = + hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string()); + Ok(format!( + "## Runtime\n\nHost: {host} | OS: {} | Model: {}", + std::env::consts::OS, + ctx.model_name + )) + } +} + +impl PromptSection for DateTimeSection { + fn name(&self) -> &str { + "datetime" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> Result { + let now = Local::now(); + Ok(format!( + "## Current Date & Time\n\nTimezone: {}", + now.format("%Z") + )) + } +} + +fn inject_workspace_file(prompt: &mut String, workspace_dir: &Path, filename: &str) { + let path = workspace_dir.join(filename); + match std::fs::read_to_string(&path) { + Ok(content) => { + let trimmed = content.trim(); + if trimmed.is_empty() { + return; + } + let _ = writeln!(prompt, "### {filename}\n"); + let truncated = if trimmed.chars().count() > BOOTSTRAP_MAX_CHARS { + trimmed + .char_indices() + .nth(BOOTSTRAP_MAX_CHARS) + .map(|(idx, _)| &trimmed[..idx]) + .unwrap_or(trimmed) + } else { + trimmed + }; + prompt.push_str(truncated); + if truncated.len() < trimmed.len() { + let _ = writeln!( + prompt, + "\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n" + ); + } else { + prompt.push_str("\n\n"); + } + } + Err(_) => { + let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tools::traits::Tool; + use async_trait::async_trait; + + struct TestTool; + + #[async_trait] + impl Tool for TestTool { + fn name(&self) -> &str { + "test_tool" + } + + fn description(&self) -> &str { + "tool desc" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute( + &self, + _args: serde_json::Value, + ) -> anyhow::Result { + Ok(crate::tools::ToolResult { + success: true, + output: "ok".into(), + error: None, + }) + } + } + + #[test] + fn prompt_builder_assembles_sections() { + let tools: Vec> = vec![Box::new(TestTool)]; + let ctx = PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "test-model", + tools: &tools, + skills: &[], + identity_config: None, + dispatcher_instructions: "instr", + }; + let prompt = SystemPromptBuilder::with_defaults().build(&ctx).unwrap(); + assert!(prompt.contains("## Tools")); + assert!(prompt.contains("test_tool")); + assert!(prompt.contains("instr")); + } +} diff --git a/src/agent/tests.rs b/src/agent/tests.rs new file mode 100644 index 0000000..63058d0 --- /dev/null +++ b/src/agent/tests.rs @@ -0,0 +1,1269 @@ +//! Comprehensive agent-loop test suite. +//! +//! Tests exercise the full `Agent.turn()` cycle with mock providers and tools, +//! covering every edge case an agentic tool loop must handle: +//! +//! 1. Simple text response (no tools) +//! 2. Single tool call → final response +//! 3. Multi-step tool chain (tool A → tool B → response) +//! 4. Max-iteration bailout +//! 5. Unknown tool name recovery +//! 6. Tool execution failure recovery +//! 7. Parallel tool dispatch +//! 8. History trimming during long conversations +//! 9. Memory auto-save round-trip +//! 10. Native vs XML dispatcher integration +//! 11. Empty / whitespace-only LLM responses +//! 12. Mixed text + tool call responses +//! 13. Multi-tool batch in a single response +//! 14. System prompt generation & tool instructions +//! 15. Context enrichment from memory loader +//! 16. ConversationMessage serialization round-trip +//! 17. Tool call with stringified JSON arguments +//! 18. Conversation history fidelity (tool call → tool result → assistant) +//! 19. Builder validation (missing required fields) +//! 20. Idempotent system prompt insertion + +use crate::agent::agent::Agent; +use crate::agent::dispatcher::{ + NativeToolDispatcher, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher, +}; +use crate::config::{AgentConfig, MemoryConfig}; +use crate::memory::{self, Memory}; +use crate::observability::{NoopObserver, Observer}; +use crate::providers::{ + ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall, + ToolResultMessage, +}; +use crate::tools::{Tool, ToolResult}; +use anyhow::Result; +use async_trait::async_trait; +use std::sync::{Arc, Mutex}; + +// ═══════════════════════════════════════════════════════════════════════════ +// Test Helpers — Mock Provider, Mock Tool, Mock Memory +// ═══════════════════════════════════════════════════════════════════════════ + +/// A mock LLM provider that returns pre-scripted responses in order. +/// When the queue is exhausted it returns a simple "done" text response. +struct ScriptedProvider { + responses: Mutex>, + /// Records every request for assertion. + requests: Mutex>>, +} + +impl ScriptedProvider { + fn new(responses: Vec) -> Self { + Self { + responses: Mutex::new(responses), + requests: Mutex::new(Vec::new()), + } + } + + fn request_count(&self) -> usize { + self.requests.lock().unwrap().len() + } +} + +#[async_trait] +impl Provider for ScriptedProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("fallback".into()) + } + + async fn chat( + &self, + request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + self.requests + .lock() + .unwrap() + .push(request.messages.to_vec()); + + let mut guard = self.responses.lock().unwrap(); + if guard.is_empty() { + return Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } +} + +/// A mock provider that always returns an error. +struct FailingProvider; + +#[async_trait] +impl Provider for FailingProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + anyhow::bail!("provider error") + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + anyhow::bail!("provider error") + } +} + +/// A simple echo tool that returns its arguments as output. +struct EchoTool; + +#[async_trait] +impl Tool for EchoTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "Echoes the input" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "message": {"type": "string"} + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let msg = args + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("(empty)") + .to_string(); + Ok(ToolResult { + success: true, + output: msg, + error: None, + }) + } +} + +/// A tool that always fails execution. +struct FailingTool; + +#[async_trait] +impl Tool for FailingTool { + fn name(&self) -> &str { + "fail" + } + + fn description(&self) -> &str { + "Always fails" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult { + success: false, + output: String::new(), + error: Some("intentional failure".into()), + }) + } +} + +/// A tool that panics (tests error propagation). +struct PanickingTool; + +#[async_trait] +impl Tool for PanickingTool { + fn name(&self) -> &str { + "panicker" + } + + fn description(&self) -> &str { + "Panics on execution" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + anyhow::bail!("catastrophic tool failure") + } +} + +/// A tool that tracks how many times it was called. +struct CountingTool { + count: Arc>, +} + +impl CountingTool { + fn new() -> (Self, Arc>) { + let count = Arc::new(Mutex::new(0)); + ( + Self { + count: count.clone(), + }, + count, + ) + } +} + +#[async_trait] +impl Tool for CountingTool { + fn name(&self) -> &str { + "counter" + } + + fn description(&self) -> &str { + "Counts calls" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + let mut c = self.count.lock().unwrap(); + *c += 1; + Ok(ToolResult { + success: true, + output: format!("call #{}", *c), + error: None, + }) + } +} + +fn make_memory() -> Arc { + let cfg = MemoryConfig { + backend: "none".into(), + ..MemoryConfig::default() + }; + Arc::from(memory::create_memory(&cfg, std::path::Path::new("/tmp"), None).unwrap()) +} + +fn make_sqlite_memory() -> (Arc, tempfile::TempDir) { + let tmp = tempfile::TempDir::new().unwrap(); + let cfg = MemoryConfig { + backend: "sqlite".into(), + ..MemoryConfig::default() + }; + let mem = Arc::from(memory::create_memory(&cfg, tmp.path(), None).unwrap()); + (mem, tmp) +} + +fn make_observer() -> Arc { + Arc::from(NoopObserver {}) +} + +fn build_agent_with( + provider: Box, + tools: Vec>, + dispatcher: Box, +) -> Agent { + Agent::builder() + .provider(provider) + .tools(tools) + .memory(make_memory()) + .observer(make_observer()) + .tool_dispatcher(dispatcher) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build() + .unwrap() +} + +fn build_agent_with_memory( + provider: Box, + tools: Vec>, + mem: Arc, + auto_save: bool, +) -> Agent { + Agent::builder() + .provider(provider) + .tools(tools) + .memory(mem) + .observer(make_observer()) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .auto_save(auto_save) + .build() + .unwrap() +} + +fn build_agent_with_config( + provider: Box, + tools: Vec>, + config: AgentConfig, +) -> Agent { + Agent::builder() + .provider(provider) + .tools(tools) + .memory(make_memory()) + .observer(make_observer()) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .config(config) + .build() + .unwrap() +} + +/// Helper: create a ChatResponse with tool calls (native format). +fn tool_response(calls: Vec) -> ChatResponse { + ChatResponse { + text: Some(String::new()), + tool_calls: calls, + } +} + +/// Helper: create a plain text ChatResponse. +fn text_response(text: &str) -> ChatResponse { + ChatResponse { + text: Some(text.into()), + tool_calls: vec![], + } +} + +/// Helper: create an XML-style tool call response. +fn xml_tool_response(name: &str, args: &str) -> ChatResponse { + ChatResponse { + text: Some(format!( + "\n{{\"name\": \"{name}\", \"arguments\": {args}}}\n" + )), + tool_calls: vec![], + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 1. Simple text response (no tools) +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_returns_text_when_no_tools_called() { + let provider = Box::new(ScriptedProvider::new(vec![text_response("Hello world")])); + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("hi").await.unwrap(); + assert_eq!(response, "Hello world"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 2. Single tool call → final response +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_executes_single_tool_then_returns() { + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "hello from tool"}"#.into(), + }]), + text_response("I ran the tool"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("run echo").await.unwrap(); + assert_eq!(response, "I ran the tool"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 3. Multi-step tool chain (tool A → tool B → response) +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_handles_multi_step_tool_chain() { + let (counting_tool, count) = CountingTool::new(); + + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "counter".into(), + arguments: "{}".into(), + }]), + tool_response(vec![ToolCall { + id: "tc2".into(), + name: "counter".into(), + arguments: "{}".into(), + }]), + tool_response(vec![ToolCall { + id: "tc3".into(), + name: "counter".into(), + arguments: "{}".into(), + }]), + text_response("Done after 3 calls"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(counting_tool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("count 3 times").await.unwrap(); + assert_eq!(response, "Done after 3 calls"); + assert_eq!(*count.lock().unwrap(), 3); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 4. Max-iteration bailout +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_bails_out_at_max_iterations() { + // Create more tool calls than max_tool_iterations allows. + let max_iters = 3; + let mut responses = Vec::new(); + for i in 0..max_iters + 5 { + responses.push(tool_response(vec![ToolCall { + id: format!("tc{i}"), + name: "echo".into(), + arguments: r#"{"message": "loop"}"#.into(), + }])); + } + + let provider = Box::new(ScriptedProvider::new(responses)); + + let config = AgentConfig { + max_tool_iterations: max_iters, + ..AgentConfig::default() + }; + + let mut agent = build_agent_with_config(provider, vec![Box::new(EchoTool)], config); + + let result = agent.turn("infinite loop").await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("maximum tool iterations"), + "Expected max iterations error, got: {err}" + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 5. Unknown tool name recovery +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_handles_unknown_tool_gracefully() { + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "nonexistent_tool".into(), + arguments: "{}".into(), + }]), + text_response("I couldn't find that tool"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("use nonexistent").await.unwrap(); + assert_eq!(response, "I couldn't find that tool"); + + // Verify the tool result mentioned "Unknown tool" + let has_tool_result = agent.history().iter().any(|msg| match msg { + ConversationMessage::ToolResults(results) => { + results.iter().any(|r| r.content.contains("Unknown tool")) + } + _ => false, + }); + assert!( + has_tool_result, + "Expected tool result with 'Unknown tool' message" + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 6. Tool execution failure recovery +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_recovers_from_tool_failure() { + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "fail".into(), + arguments: "{}".into(), + }]), + text_response("Tool failed but I recovered"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(FailingTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("try failing tool").await.unwrap(); + assert_eq!(response, "Tool failed but I recovered"); +} + +#[tokio::test] +async fn turn_recovers_from_tool_error() { + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "panicker".into(), + arguments: "{}".into(), + }]), + text_response("I recovered from the error"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(PanickingTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("try panicking").await.unwrap(); + assert_eq!(response, "I recovered from the error"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 7. Provider error propagation +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_propagates_provider_error() { + let mut agent = build_agent_with( + Box::new(FailingProvider), + vec![], + Box::new(NativeToolDispatcher), + ); + + let result = agent.turn("hello").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("provider error")); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 8. History trimming during long conversations +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn history_trims_after_max_messages() { + let max_history = 6; + let mut responses = vec![]; + for _ in 0..max_history + 5 { + responses.push(text_response("ok")); + } + + let provider = Box::new(ScriptedProvider::new(responses)); + let config = AgentConfig { + max_history_messages: max_history, + ..AgentConfig::default() + }; + + let mut agent = build_agent_with_config(provider, vec![], config); + + for i in 0..max_history + 5 { + let _ = agent.turn(&format!("msg {i}")).await.unwrap(); + } + + // System prompt (1) + trimmed messages + // Should not exceed max_history + 1 (system prompt) + assert!( + agent.history().len() <= max_history + 1, + "History length {} exceeds max {} + 1 (system)", + agent.history().len(), + max_history, + ); + + // System prompt should always be preserved + let first = &agent.history()[0]; + assert!(matches!(first, ConversationMessage::Chat(c) if c.role == "system")); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 9. Memory auto-save round-trip +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn auto_save_stores_messages_in_memory() { + let (mem, _tmp) = make_sqlite_memory(); + let provider = Box::new(ScriptedProvider::new(vec![text_response( + "I remember everything", + )])); + + let mut agent = build_agent_with_memory( + provider, + vec![], + mem.clone(), + true, // auto_save enabled + ); + + let _ = agent.turn("Remember this fact").await.unwrap(); + + // Both user message and assistant response should be saved + let count = mem.count().await.unwrap(); + assert!( + count >= 2, + "Expected at least 2 memory entries, got {count}" + ); +} + +#[tokio::test] +async fn auto_save_disabled_does_not_store() { + let (mem, _tmp) = make_sqlite_memory(); + let provider = Box::new(ScriptedProvider::new(vec![text_response("hello")])); + + let mut agent = build_agent_with_memory( + provider, + vec![], + mem.clone(), + false, // auto_save disabled + ); + + let _ = agent.turn("test message").await.unwrap(); + + let count = mem.count().await.unwrap(); + assert_eq!(count, 0, "Expected 0 memory entries with auto_save off"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 10. Native vs XML dispatcher integration +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn xml_dispatcher_parses_and_loops() { + let provider = Box::new(ScriptedProvider::new(vec![ + xml_tool_response("echo", r#"{"message": "xml-test"}"#), + text_response("XML tool completed"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(XmlToolDispatcher), + ); + + let response = agent.turn("test xml").await.unwrap(); + assert_eq!(response, "XML tool completed"); +} + +#[tokio::test] +async fn native_dispatcher_sends_tool_specs() { + let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")])); + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let _ = agent.turn("hi").await.unwrap(); + + // NativeToolDispatcher.should_send_tool_specs() returns true + let dispatcher = NativeToolDispatcher; + assert!(dispatcher.should_send_tool_specs()); +} + +#[tokio::test] +async fn xml_dispatcher_does_not_send_tool_specs() { + let dispatcher = XmlToolDispatcher; + assert!(!dispatcher.should_send_tool_specs()); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 11. Empty / whitespace-only LLM responses +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_handles_empty_text_response() { + let provider = Box::new(ScriptedProvider::new(vec![ChatResponse { + text: Some(String::new()), + tool_calls: vec![], + }])); + + let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher)); + + let response = agent.turn("hi").await.unwrap(); + assert!(response.is_empty()); +} + +#[tokio::test] +async fn turn_handles_none_text_response() { + let provider = Box::new(ScriptedProvider::new(vec![ChatResponse { + text: None, + tool_calls: vec![], + }])); + + let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher)); + + // Should not panic — falls back to empty string + let response = agent.turn("hi").await.unwrap(); + assert!(response.is_empty()); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 12. Mixed text + tool call responses +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_preserves_text_alongside_tool_calls() { + let provider = Box::new(ScriptedProvider::new(vec![ + ChatResponse { + text: Some("Let me check...".into()), + tool_calls: vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "hi"}"#.into(), + }], + }, + text_response("Here are the results"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("check something").await.unwrap(); + assert_eq!(response, "Here are the results"); + + // The intermediate text should be in history + let has_intermediate = agent.history().iter().any(|msg| match msg { + ConversationMessage::Chat(c) => c.role == "assistant" && c.content.contains("Let me check"), + _ => false, + }); + assert!(has_intermediate, "Intermediate text should be in history"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 13. Multi-tool batch in a single response +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn turn_handles_multiple_tools_in_one_response() { + let (counting_tool, count) = CountingTool::new(); + + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ + ToolCall { + id: "tc1".into(), + name: "counter".into(), + arguments: "{}".into(), + }, + ToolCall { + id: "tc2".into(), + name: "counter".into(), + arguments: "{}".into(), + }, + ToolCall { + id: "tc3".into(), + name: "counter".into(), + arguments: "{}".into(), + }, + ]), + text_response("All 3 done"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(counting_tool)], + Box::new(NativeToolDispatcher), + ); + + let response = agent.turn("batch").await.unwrap(); + assert_eq!(response, "All 3 done"); + assert_eq!( + *count.lock().unwrap(), + 3, + "All 3 tools should have been called" + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 14. System prompt generation & tool instructions +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn system_prompt_injected_on_first_turn() { + let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")])); + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + assert!(agent.history().is_empty(), "History should start empty"); + + let _ = agent.turn("hi").await.unwrap(); + + // First message should be the system prompt + let first = &agent.history()[0]; + assert!( + matches!(first, ConversationMessage::Chat(c) if c.role == "system"), + "First history entry should be system prompt" + ); +} + +#[tokio::test] +async fn system_prompt_not_duplicated_on_second_turn() { + let provider = Box::new(ScriptedProvider::new(vec![ + text_response("first"), + text_response("second"), + ])); + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let _ = agent.turn("hi").await.unwrap(); + let _ = agent.turn("hello again").await.unwrap(); + + let system_count = agent + .history() + .iter() + .filter(|msg| matches!(msg, ConversationMessage::Chat(c) if c.role == "system")) + .count(); + assert_eq!(system_count, 1, "System prompt should appear exactly once"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 15. Conversation history fidelity +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn history_contains_all_expected_entries_after_tool_loop() { + let provider = Box::new(ScriptedProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "tool-out"}"#.into(), + }]), + text_response("final answer"), + ])); + + let mut agent = build_agent_with( + provider, + vec![Box::new(EchoTool)], + Box::new(NativeToolDispatcher), + ); + + let _ = agent.turn("test").await.unwrap(); + + // Expected history entries: + // 0: system prompt + // 1: user message "test" + // 2: AssistantToolCalls + // 3: ToolResults + // 4: assistant "final answer" + let history = agent.history(); + assert!( + history.len() >= 5, + "Expected at least 5 history entries, got {}", + history.len() + ); + + assert!(matches!(&history[0], ConversationMessage::Chat(c) if c.role == "system")); + assert!(matches!(&history[1], ConversationMessage::Chat(c) if c.role == "user")); + assert!(matches!( + &history[2], + ConversationMessage::AssistantToolCalls { .. } + )); + assert!(matches!(&history[3], ConversationMessage::ToolResults(_))); + assert!( + matches!(&history[4], ConversationMessage::Chat(c) if c.role == "assistant" && c.content == "final answer") + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 16. Builder validation +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn builder_fails_without_provider() { + let result = Agent::builder() + .tools(vec![]) + .memory(make_memory()) + .observer(make_observer()) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build(); + + assert!(result.is_err(), "Building without provider should fail"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 17. Multi-turn conversation maintains context +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn multi_turn_maintains_growing_history() { + let provider = Box::new(ScriptedProvider::new(vec![ + text_response("response 1"), + text_response("response 2"), + text_response("response 3"), + ])); + + let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher)); + + let r1 = agent.turn("msg 1").await.unwrap(); + let len_after_1 = agent.history().len(); + + let r2 = agent.turn("msg 2").await.unwrap(); + let len_after_2 = agent.history().len(); + + let r3 = agent.turn("msg 3").await.unwrap(); + let len_after_3 = agent.history().len(); + + assert_eq!(r1, "response 1"); + assert_eq!(r2, "response 2"); + assert_eq!(r3, "response 3"); + + // History should grow with each turn (user + assistant per turn) + assert!( + len_after_2 > len_after_1, + "History should grow after turn 2" + ); + assert!( + len_after_3 > len_after_2, + "History should grow after turn 3" + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 18. Tool call with stringified JSON arguments (common LLM pattern) +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn native_dispatcher_handles_stringified_arguments() { + let dispatcher = NativeToolDispatcher; + let response = ChatResponse { + text: Some(String::new()), + tool_calls: vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "hello"}"#.into(), + }], + }; + + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "echo"); + assert_eq!( + calls[0].arguments.get("message").unwrap().as_str().unwrap(), + "hello" + ); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 19. XML dispatcher edge cases +// ═══════════════════════════════════════════════════════════════════════════ + +#[test] +fn xml_dispatcher_handles_nested_json() { + let response = ChatResponse { + text: Some( + r#" +{"name": "file_write", "arguments": {"path": "test.json", "content": "{\"key\": \"value\"}"}} +"# + .into(), + ), + tool_calls: vec![], + }; + + let dispatcher = XmlToolDispatcher; + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "file_write"); + assert_eq!( + calls[0].arguments.get("path").unwrap().as_str().unwrap(), + "test.json" + ); +} + +#[test] +fn xml_dispatcher_handles_empty_tool_call_tag() { + let response = ChatResponse { + text: Some("\n\nSome text".into()), + tool_calls: vec![], + }; + + let dispatcher = XmlToolDispatcher; + let (text, calls) = dispatcher.parse_response(&response); + assert!(calls.is_empty()); + assert!(text.contains("Some text")); +} + +#[test] +fn xml_dispatcher_handles_unclosed_tool_call() { + let response = ChatResponse { + text: Some("Before\n\n{\"name\": \"shell\"}".into()), + tool_calls: vec![], + }; + + let dispatcher = XmlToolDispatcher; + let (text, calls) = dispatcher.parse_response(&response); + // Should not panic — just treat as text + assert!(calls.is_empty()); + assert!(text.contains("Before")); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 20. ConversationMessage serialization round-trip +// ═══════════════════════════════════════════════════════════════════════════ + +#[test] +fn conversation_message_serialization_roundtrip() { + let messages = vec![ + ConversationMessage::Chat(ChatMessage::system("system")), + ConversationMessage::Chat(ChatMessage::user("hello")), + ConversationMessage::AssistantToolCalls { + text: Some("checking".into()), + tool_calls: vec![ToolCall { + id: "tc1".into(), + name: "shell".into(), + arguments: "{}".into(), + }], + }, + ConversationMessage::ToolResults(vec![ToolResultMessage { + tool_call_id: "tc1".into(), + content: "ok".into(), + }]), + ConversationMessage::Chat(ChatMessage::assistant("done")), + ]; + + for msg in &messages { + let json = serde_json::to_string(msg).unwrap(); + let parsed: ConversationMessage = serde_json::from_str(&json).unwrap(); + + // Verify the variant type matches + match (msg, &parsed) { + (ConversationMessage::Chat(a), ConversationMessage::Chat(b)) => { + assert_eq!(a.role, b.role); + assert_eq!(a.content, b.content); + } + ( + ConversationMessage::AssistantToolCalls { + text: a_text, + tool_calls: a_calls, + }, + ConversationMessage::AssistantToolCalls { + text: b_text, + tool_calls: b_calls, + }, + ) => { + assert_eq!(a_text, b_text); + assert_eq!(a_calls.len(), b_calls.len()); + } + (ConversationMessage::ToolResults(a), ConversationMessage::ToolResults(b)) => { + assert_eq!(a.len(), b.len()); + } + _ => panic!("Variant mismatch after serialization"), + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 21. Tool dispatcher format_results +// ═══════════════════════════════════════════════════════════════════════════ + +#[test] +fn xml_format_results_includes_status_and_output() { + let dispatcher = XmlToolDispatcher; + let results = vec![ + ToolExecutionResult { + name: "shell".into(), + output: "file1.txt\nfile2.txt".into(), + success: true, + tool_call_id: None, + }, + ToolExecutionResult { + name: "file_read".into(), + output: "Error: file not found".into(), + success: false, + tool_call_id: None, + }, + ]; + + let msg = dispatcher.format_results(&results); + let content = match msg { + ConversationMessage::Chat(c) => c.content, + _ => panic!("Expected Chat variant"), + }; + + assert!(content.contains("shell")); + assert!(content.contains("file1.txt")); + assert!(content.contains("ok")); + assert!(content.contains("file_read")); + assert!(content.contains("error")); +} + +#[test] +fn native_format_results_maps_tool_call_ids() { + let dispatcher = NativeToolDispatcher; + let results = vec![ + ToolExecutionResult { + name: "a".into(), + output: "out1".into(), + success: true, + tool_call_id: Some("tc-001".into()), + }, + ToolExecutionResult { + name: "b".into(), + output: "out2".into(), + success: true, + tool_call_id: Some("tc-002".into()), + }, + ]; + + let msg = dispatcher.format_results(&results); + match msg { + ConversationMessage::ToolResults(r) => { + assert_eq!(r.len(), 2); + assert_eq!(r[0].tool_call_id, "tc-001"); + assert_eq!(r[0].content, "out1"); + assert_eq!(r[1].tool_call_id, "tc-002"); + assert_eq!(r[1].content, "out2"); + } + _ => panic!("Expected ToolResults"), + } +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 22. to_provider_messages conversion +// ═══════════════════════════════════════════════════════════════════════════ + +#[test] +fn xml_dispatcher_converts_history_to_provider_messages() { + let dispatcher = XmlToolDispatcher; + let history = vec![ + ConversationMessage::Chat(ChatMessage::system("sys")), + ConversationMessage::Chat(ChatMessage::user("hi")), + ConversationMessage::AssistantToolCalls { + text: Some("checking".into()), + tool_calls: vec![ToolCall { + id: "tc1".into(), + name: "shell".into(), + arguments: "{}".into(), + }], + }, + ConversationMessage::ToolResults(vec![ToolResultMessage { + tool_call_id: "tc1".into(), + content: "ok".into(), + }]), + ConversationMessage::Chat(ChatMessage::assistant("done")), + ]; + + let messages = dispatcher.to_provider_messages(&history); + + // Should have: system, user, assistant (from tool calls), user (tool results), assistant + assert!(messages.len() >= 4); + assert_eq!(messages[0].role, "system"); + assert_eq!(messages[1].role, "user"); +} + +#[test] +fn native_dispatcher_converts_tool_results_to_tool_messages() { + let dispatcher = NativeToolDispatcher; + let history = vec![ConversationMessage::ToolResults(vec![ + ToolResultMessage { + tool_call_id: "tc1".into(), + content: "output1".into(), + }, + ToolResultMessage { + tool_call_id: "tc2".into(), + content: "output2".into(), + }, + ])]; + + let messages = dispatcher.to_provider_messages(&history); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, "tool"); + assert_eq!(messages[1].role, "tool"); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 23. XML tool instructions generation +// ═══════════════════════════════════════════════════════════════════════════ + +#[test] +fn xml_dispatcher_generates_tool_instructions() { + let tools: Vec> = vec![Box::new(EchoTool)]; + let dispatcher = XmlToolDispatcher; + let instructions = dispatcher.prompt_instructions(&tools); + + assert!(instructions.contains("## Tool Use Protocol")); + assert!(instructions.contains("")); + assert!(instructions.contains("echo")); + assert!(instructions.contains("Echoes the input")); +} + +#[test] +fn native_dispatcher_returns_empty_instructions() { + let tools: Vec> = vec![Box::new(EchoTool)]; + let dispatcher = NativeToolDispatcher; + let instructions = dispatcher.prompt_instructions(&tools); + assert!(instructions.is_empty()); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 24. Clear history +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn clear_history_resets_conversation() { + let provider = Box::new(ScriptedProvider::new(vec![ + text_response("first"), + text_response("second"), + ])); + + let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher)); + + let _ = agent.turn("hi").await.unwrap(); + assert!(!agent.history().is_empty()); + + agent.clear_history(); + assert!(agent.history().is_empty()); + + // Next turn should re-inject system prompt + let _ = agent.turn("hello again").await.unwrap(); + assert!(matches!( + &agent.history()[0], + ConversationMessage::Chat(c) if c.role == "system" + )); +} + +// ═══════════════════════════════════════════════════════════════════════════ +// 25. run_single delegates to turn +// ═══════════════════════════════════════════════════════════════════════════ + +#[tokio::test] +async fn run_single_delegates_to_turn() { + let provider = Box::new(ScriptedProvider::new(vec![text_response("via run_single")])); + let mut agent = build_agent_with(provider, vec![], Box::new(NativeToolDispatcher)); + + let response = agent.run_single("test").await.unwrap(); + assert_eq!(response, "via run_single"); +} diff --git a/src/approval/mod.rs b/src/approval/mod.rs new file mode 100644 index 0000000..ea5b02b --- /dev/null +++ b/src/approval/mod.rs @@ -0,0 +1,426 @@ +//! Interactive approval workflow for supervised mode. +//! +//! Provides a pre-execution hook that prompts the user before tool calls, +//! with session-scoped "Always" allowlists and audit logging. + +use crate::config::AutonomyConfig; +use crate::security::AutonomyLevel; +use chrono::Utc; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use std::io::{self, BufRead, Write}; + +// ── Types ──────────────────────────────────────────────────────── + +/// A request to approve a tool call before execution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApprovalRequest { + pub tool_name: String, + pub arguments: serde_json::Value, +} + +/// The user's response to an approval request. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ApprovalResponse { + /// Execute this one call. + Yes, + /// Deny this call. + No, + /// Execute and add tool to session-scoped allowlist. + Always, +} + +/// A single audit log entry for an approval decision. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApprovalLogEntry { + pub timestamp: String, + pub tool_name: String, + pub arguments_summary: String, + pub decision: ApprovalResponse, + pub channel: String, +} + +// ── ApprovalManager ────────────────────────────────────────────── + +/// Manages the interactive approval workflow. +/// +/// - Checks config-level `auto_approve` / `always_ask` lists +/// - Maintains a session-scoped "always" allowlist +/// - Records an audit trail of all decisions +pub struct ApprovalManager { + /// Tools that never need approval (from config). + auto_approve: HashSet, + /// Tools that always need approval, ignoring session allowlist. + always_ask: HashSet, + /// Autonomy level from config. + autonomy_level: AutonomyLevel, + /// Session-scoped allowlist built from "Always" responses. + session_allowlist: Mutex>, + /// Audit trail of approval decisions. + audit_log: Mutex>, +} + +impl ApprovalManager { + /// Create from autonomy config. + pub fn from_config(config: &AutonomyConfig) -> Self { + Self { + auto_approve: config.auto_approve.iter().cloned().collect(), + always_ask: config.always_ask.iter().cloned().collect(), + autonomy_level: config.level, + session_allowlist: Mutex::new(HashSet::new()), + audit_log: Mutex::new(Vec::new()), + } + } + + /// Check whether a tool call requires interactive approval. + /// + /// Returns `true` if the call needs a prompt, `false` if it can proceed. + pub fn needs_approval(&self, tool_name: &str) -> bool { + // Full autonomy never prompts. + if self.autonomy_level == AutonomyLevel::Full { + return false; + } + + // ReadOnly blocks everything — handled elsewhere; no prompt needed. + if self.autonomy_level == AutonomyLevel::ReadOnly { + return false; + } + + // always_ask overrides everything. + if self.always_ask.contains(tool_name) { + return true; + } + + // auto_approve skips the prompt. + if self.auto_approve.contains(tool_name) { + return false; + } + + // Session allowlist (from prior "Always" responses). + let allowlist = self.session_allowlist.lock(); + if allowlist.contains(tool_name) { + return false; + } + + // Default: supervised mode requires approval. + true + } + + /// Record an approval decision and update session state. + pub fn record_decision( + &self, + tool_name: &str, + args: &serde_json::Value, + decision: ApprovalResponse, + channel: &str, + ) { + // If "Always", add to session allowlist. + if decision == ApprovalResponse::Always { + let mut allowlist = self.session_allowlist.lock(); + allowlist.insert(tool_name.to_string()); + } + + // Append to audit log. + let summary = summarize_args(args); + let entry = ApprovalLogEntry { + timestamp: Utc::now().to_rfc3339(), + tool_name: tool_name.to_string(), + arguments_summary: summary, + decision, + channel: channel.to_string(), + }; + let mut log = self.audit_log.lock(); + log.push(entry); + } + + /// Get a snapshot of the audit log. + pub fn audit_log(&self) -> Vec { + self.audit_log.lock().clone() + } + + /// Get the current session allowlist. + pub fn session_allowlist(&self) -> HashSet { + self.session_allowlist.lock().clone() + } + + /// Prompt the user on the CLI and return their decision. + /// + /// For non-CLI channels, returns `Yes` automatically (interactive + /// approval is only supported on CLI for now). + pub fn prompt_cli(&self, request: &ApprovalRequest) -> ApprovalResponse { + prompt_cli_interactive(request) + } +} + +// ── CLI prompt ─────────────────────────────────────────────────── + +/// Display the approval prompt and read user input from stdin. +fn prompt_cli_interactive(request: &ApprovalRequest) -> ApprovalResponse { + let summary = summarize_args(&request.arguments); + eprintln!(); + eprintln!("🔧 Agent wants to execute: {}", request.tool_name); + eprintln!(" {summary}"); + eprint!(" [Y]es / [N]o / [A]lways for {}: ", request.tool_name); + let _ = io::stderr().flush(); + + let stdin = io::stdin(); + let mut line = String::new(); + if stdin.lock().read_line(&mut line).is_err() { + return ApprovalResponse::No; + } + + match line.trim().to_ascii_lowercase().as_str() { + "y" | "yes" => ApprovalResponse::Yes, + "a" | "always" => ApprovalResponse::Always, + _ => ApprovalResponse::No, + } +} + +/// Produce a short human-readable summary of tool arguments. +fn summarize_args(args: &serde_json::Value) -> String { + match args { + serde_json::Value::Object(map) => { + let parts: Vec = map + .iter() + .map(|(k, v)| { + let val = match v { + serde_json::Value::String(s) => truncate_for_summary(s, 80), + other => { + let s = other.to_string(); + truncate_for_summary(&s, 80) + } + }; + format!("{k}: {val}") + }) + .collect(); + parts.join(", ") + } + other => { + let s = other.to_string(); + truncate_for_summary(&s, 120) + } + } +} + +fn truncate_for_summary(input: &str, max_chars: usize) -> String { + let mut chars = input.chars(); + let truncated: String = chars.by_ref().take(max_chars).collect(); + if chars.next().is_some() { + format!("{truncated}…") + } else { + input.to_string() + } +} + +// ── Tests ──────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::AutonomyConfig; + + fn supervised_config() -> AutonomyConfig { + AutonomyConfig { + level: AutonomyLevel::Supervised, + auto_approve: vec!["file_read".into(), "memory_recall".into()], + always_ask: vec!["shell".into()], + ..AutonomyConfig::default() + } + } + + fn full_config() -> AutonomyConfig { + AutonomyConfig { + level: AutonomyLevel::Full, + ..AutonomyConfig::default() + } + } + + // ── needs_approval ─────────────────────────────────────── + + #[test] + fn auto_approve_tools_skip_prompt() { + let mgr = ApprovalManager::from_config(&supervised_config()); + assert!(!mgr.needs_approval("file_read")); + assert!(!mgr.needs_approval("memory_recall")); + } + + #[test] + fn always_ask_tools_always_prompt() { + let mgr = ApprovalManager::from_config(&supervised_config()); + assert!(mgr.needs_approval("shell")); + } + + #[test] + fn unknown_tool_needs_approval_in_supervised() { + let mgr = ApprovalManager::from_config(&supervised_config()); + assert!(mgr.needs_approval("file_write")); + assert!(mgr.needs_approval("http_request")); + } + + #[test] + fn full_autonomy_never_prompts() { + let mgr = ApprovalManager::from_config(&full_config()); + assert!(!mgr.needs_approval("shell")); + assert!(!mgr.needs_approval("file_write")); + assert!(!mgr.needs_approval("anything")); + } + + #[test] + fn readonly_never_prompts() { + let config = AutonomyConfig { + level: AutonomyLevel::ReadOnly, + ..AutonomyConfig::default() + }; + let mgr = ApprovalManager::from_config(&config); + assert!(!mgr.needs_approval("shell")); + } + + // ── session allowlist ──────────────────────────────────── + + #[test] + fn always_response_adds_to_session_allowlist() { + let mgr = ApprovalManager::from_config(&supervised_config()); + assert!(mgr.needs_approval("file_write")); + + mgr.record_decision( + "file_write", + &serde_json::json!({"path": "test.txt"}), + ApprovalResponse::Always, + "cli", + ); + + // Now file_write should be in session allowlist. + assert!(!mgr.needs_approval("file_write")); + } + + #[test] + fn always_ask_overrides_session_allowlist() { + let mgr = ApprovalManager::from_config(&supervised_config()); + + // Even after "Always" for shell, it should still prompt. + mgr.record_decision( + "shell", + &serde_json::json!({"command": "ls"}), + ApprovalResponse::Always, + "cli", + ); + + // shell is in always_ask, so it still needs approval. + assert!(mgr.needs_approval("shell")); + } + + #[test] + fn yes_response_does_not_add_to_allowlist() { + let mgr = ApprovalManager::from_config(&supervised_config()); + mgr.record_decision( + "file_write", + &serde_json::json!({}), + ApprovalResponse::Yes, + "cli", + ); + assert!(mgr.needs_approval("file_write")); + } + + // ── audit log ──────────────────────────────────────────── + + #[test] + fn audit_log_records_decisions() { + let mgr = ApprovalManager::from_config(&supervised_config()); + + mgr.record_decision( + "shell", + &serde_json::json!({"command": "rm -rf ./build/"}), + ApprovalResponse::No, + "cli", + ); + mgr.record_decision( + "file_write", + &serde_json::json!({"path": "out.txt", "content": "hello"}), + ApprovalResponse::Yes, + "cli", + ); + + let log = mgr.audit_log(); + assert_eq!(log.len(), 2); + assert_eq!(log[0].tool_name, "shell"); + assert_eq!(log[0].decision, ApprovalResponse::No); + assert_eq!(log[1].tool_name, "file_write"); + assert_eq!(log[1].decision, ApprovalResponse::Yes); + } + + #[test] + fn audit_log_contains_timestamp_and_channel() { + let mgr = ApprovalManager::from_config(&supervised_config()); + mgr.record_decision( + "shell", + &serde_json::json!({"command": "ls"}), + ApprovalResponse::Yes, + "telegram", + ); + + let log = mgr.audit_log(); + assert_eq!(log.len(), 1); + assert!(!log[0].timestamp.is_empty()); + assert_eq!(log[0].channel, "telegram"); + } + + // ── summarize_args ─────────────────────────────────────── + + #[test] + fn summarize_args_object() { + let args = serde_json::json!({"command": "ls -la", "cwd": "/tmp"}); + let summary = summarize_args(&args); + assert!(summary.contains("command: ls -la")); + assert!(summary.contains("cwd: /tmp")); + } + + #[test] + fn summarize_args_truncates_long_values() { + let long_val = "x".repeat(200); + let args = serde_json::json!({"content": long_val}); + let summary = summarize_args(&args); + assert!(summary.contains('…')); + assert!(summary.len() < 200); + } + + #[test] + fn summarize_args_unicode_safe_truncation() { + let long_val = "🦀".repeat(120); + let args = serde_json::json!({"content": long_val}); + let summary = summarize_args(&args); + assert!(summary.contains("content:")); + assert!(summary.contains('…')); + } + + #[test] + fn summarize_args_non_object() { + let args = serde_json::json!("just a string"); + let summary = summarize_args(&args); + assert!(summary.contains("just a string")); + } + + // ── ApprovalResponse serde ─────────────────────────────── + + #[test] + fn approval_response_serde_roundtrip() { + let json = serde_json::to_string(&ApprovalResponse::Always).unwrap(); + assert_eq!(json, "\"always\""); + let parsed: ApprovalResponse = serde_json::from_str("\"no\"").unwrap(); + assert_eq!(parsed, ApprovalResponse::No); + } + + // ── ApprovalRequest ────────────────────────────────────── + + #[test] + fn approval_request_serde() { + let req = ApprovalRequest { + tool_name: "shell".into(), + arguments: serde_json::json!({"command": "echo hi"}), + }; + let json = serde_json::to_string(&req).unwrap(); + let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.tool_name, "shell"); + } +} diff --git a/src/channels/cli.rs b/src/channels/cli.rs index 8b414fd..ae49548 100644 --- a/src/channels/cli.rs +++ b/src/channels/cli.rs @@ -1,4 +1,4 @@ -use super::traits::{Channel, ChannelMessage}; +use super::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; use tokio::io::{self, AsyncBufReadExt, BufReader}; use uuid::Uuid; @@ -18,8 +18,8 @@ impl Channel for CliChannel { "cli" } - async fn send(&self, message: &str, _recipient: &str) -> anyhow::Result<()> { - println!("{message}"); + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + println!("{}", message.content); Ok(()) } @@ -40,6 +40,7 @@ impl Channel for CliChannel { let msg = ChannelMessage { id: Uuid::new_v4().to_string(), sender: "user".to_string(), + reply_target: "user".to_string(), content: line, channel: "cli".to_string(), timestamp: std::time::SystemTime::now() @@ -68,14 +69,26 @@ mod tests { #[tokio::test] async fn cli_channel_send_does_not_panic() { let ch = CliChannel::new(); - let result = ch.send("hello", "user").await; + let result = ch + .send(&SendMessage { + content: "hello".into(), + recipient: "user".into(), + subject: None, + }) + .await; assert!(result.is_ok()); } #[tokio::test] async fn cli_channel_send_empty_message() { let ch = CliChannel::new(); - let result = ch.send("", "").await; + let result = ch + .send(&SendMessage { + content: String::new(), + recipient: String::new(), + subject: None, + }) + .await; assert!(result.is_ok()); } @@ -90,12 +103,14 @@ mod tests { let msg = ChannelMessage { id: "test-id".into(), sender: "user".into(), + reply_target: "user".into(), content: "hello".into(), channel: "cli".into(), timestamp: 1_234_567_890, }; assert_eq!(msg.id, "test-id"); assert_eq!(msg.sender, "user"); + assert_eq!(msg.reply_target, "user"); assert_eq!(msg.content, "hello"); assert_eq!(msg.channel, "cli"); assert_eq!(msg.timestamp, 1_234_567_890); @@ -106,6 +121,7 @@ mod tests { let msg = ChannelMessage { id: "id".into(), sender: "s".into(), + reply_target: "s".into(), content: "c".into(), channel: "ch".into(), timestamp: 0, diff --git a/src/channels/dingtalk.rs b/src/channels/dingtalk.rs new file mode 100644 index 0000000..cd0ac7d --- /dev/null +++ b/src/channels/dingtalk.rs @@ -0,0 +1,374 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio_tungstenite::tungstenite::Message; +use uuid::Uuid; + +const DINGTALK_BOT_CALLBACK_TOPIC: &str = "/v1.0/im/bot/messages/get"; + +/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages. +/// Replies are sent through per-message session webhook URLs. +pub struct DingTalkChannel { + client_id: String, + client_secret: String, + allowed_users: Vec, + client: reqwest::Client, + /// Per-chat session webhooks for sending replies (chatID -> webhook URL). + /// DingTalk provides a unique webhook URL with each incoming message. + session_webhooks: Arc>>, +} + +/// Response from DingTalk gateway connection registration. +#[derive(serde::Deserialize)] +struct GatewayResponse { + endpoint: String, + ticket: String, +} + +impl DingTalkChannel { + pub fn new(client_id: String, client_secret: String, allowed_users: Vec) -> Self { + Self { + client_id, + client_secret, + allowed_users, + client: reqwest::Client::new(), + session_webhooks: Arc::new(RwLock::new(HashMap::new())), + } + } + + fn is_user_allowed(&self, user_id: &str) -> bool { + self.allowed_users.iter().any(|u| u == "*" || u == user_id) + } + + fn parse_stream_data(frame: &serde_json::Value) -> Option { + match frame.get("data") { + Some(serde_json::Value::String(raw)) => serde_json::from_str(raw).ok(), + Some(serde_json::Value::Object(_)) => frame.get("data").cloned(), + _ => None, + } + } + + fn resolve_chat_id(data: &serde_json::Value, sender_id: &str) -> String { + let is_private_chat = data + .get("conversationType") + .and_then(|value| { + value + .as_str() + .map(|v| v == "1") + .or_else(|| value.as_i64().map(|v| v == 1)) + }) + .unwrap_or(true); + + if is_private_chat { + sender_id.to_string() + } else { + data.get("conversationId") + .and_then(|c| c.as_str()) + .unwrap_or(sender_id) + .to_string() + } + } + + /// Register a connection with DingTalk's gateway to get a WebSocket endpoint. + async fn register_connection(&self) -> anyhow::Result { + let body = serde_json::json!({ + "clientId": self.client_id, + "clientSecret": self.client_secret, + "subscriptions": [ + { + "type": "CALLBACK", + "topic": DINGTALK_BOT_CALLBACK_TOPIC, + } + ], + }); + + let resp = self + .client + .post("https://api.dingtalk.com/v1.0/gateway/connections/open") + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("DingTalk gateway registration failed ({status}): {err}"); + } + + let gw: GatewayResponse = resp.json().await?; + Ok(gw) + } +} + +#[async_trait] +impl Channel for DingTalkChannel { + fn name(&self) -> &str { + "dingtalk" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let webhooks = self.session_webhooks.read().await; + let webhook_url = webhooks.get(&message.recipient).ok_or_else(|| { + anyhow::anyhow!( + "No session webhook found for chat {}. \ + The user must send a message first to establish a session.", + message.recipient + ) + })?; + + let title = message.subject.as_deref().unwrap_or("ZeroClaw"); + let body = serde_json::json!({ + "msgtype": "markdown", + "markdown": { + "title": title, + "text": message.content, + } + }); + + let resp = self.client.post(webhook_url).json(&body).send().await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("DingTalk webhook reply failed ({status}): {err}"); + } + + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + tracing::info!("DingTalk: registering gateway connection..."); + + let gw = self.register_connection().await?; + let ws_url = format!("{}?ticket={}", gw.endpoint, gw.ticket); + + tracing::info!("DingTalk: connecting to stream WebSocket..."); + let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?; + let (mut write, mut read) = ws_stream.split(); + + tracing::info!("DingTalk: connected and listening for messages..."); + + while let Some(msg) = read.next().await { + let msg = match msg { + Ok(Message::Text(t)) => t, + Ok(Message::Close(_)) => break, + Err(e) => { + tracing::warn!("DingTalk WebSocket error: {e}"); + break; + } + _ => continue, + }; + + let frame: serde_json::Value = match serde_json::from_str(&msg) { + Ok(v) => v, + Err(_) => continue, + }; + + let frame_type = frame.get("type").and_then(|t| t.as_str()).unwrap_or(""); + + match frame_type { + "SYSTEM" => { + // Respond to system pings to keep the connection alive + let message_id = frame + .get("headers") + .and_then(|h| h.get("messageId")) + .and_then(|m| m.as_str()) + .unwrap_or(""); + + let pong = serde_json::json!({ + "code": 200, + "headers": { + "contentType": "application/json", + "messageId": message_id, + }, + "message": "OK", + "data": "", + }); + + if let Err(e) = write.send(Message::Text(pong.to_string())).await { + tracing::warn!("DingTalk: failed to send pong: {e}"); + break; + } + } + "EVENT" | "CALLBACK" => { + // Parse the chatbot callback data from the frame. + let data = match Self::parse_stream_data(&frame) { + Some(v) => v, + None => { + tracing::debug!("DingTalk: frame has no parseable data payload"); + continue; + } + }; + + // Extract message content + let content = data + .get("text") + .and_then(|t| t.get("content")) + .and_then(|c| c.as_str()) + .unwrap_or("") + .trim(); + + if content.is_empty() { + continue; + } + + let sender_id = data + .get("senderStaffId") + .and_then(|s| s.as_str()) + .unwrap_or("unknown"); + + if !self.is_user_allowed(sender_id) { + tracing::warn!( + "DingTalk: ignoring message from unauthorized user: {sender_id}" + ); + continue; + } + + // Private chat uses sender ID, group chat uses conversation ID. + let chat_id = Self::resolve_chat_id(&data, sender_id); + + // Store session webhook for later replies + if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) { + let webhook = webhook.to_string(); + let mut webhooks = self.session_webhooks.write().await; + // Use both keys so reply routing works for both group and private flows. + webhooks.insert(chat_id.clone(), webhook.clone()); + webhooks.insert(sender_id.to_string(), webhook); + } + + // Acknowledge the event + let message_id = frame + .get("headers") + .and_then(|h| h.get("messageId")) + .and_then(|m| m.as_str()) + .unwrap_or(""); + + let ack = serde_json::json!({ + "code": 200, + "headers": { + "contentType": "application/json", + "messageId": message_id, + }, + "message": "OK", + "data": "", + }); + let _ = write.send(Message::Text(ack.to_string())).await; + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: sender_id.to_string(), + reply_target: chat_id, + content: content.to_string(), + channel: "dingtalk".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + if tx.send(channel_msg).await.is_err() { + tracing::warn!("DingTalk: message channel closed"); + break; + } + } + _ => {} + } + } + + anyhow::bail!("DingTalk WebSocket stream ended") + } + + async fn health_check(&self) -> bool { + self.register_connection().await.is_ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_name() { + let ch = DingTalkChannel::new("id".into(), "secret".into(), vec![]); + assert_eq!(ch.name(), "dingtalk"); + } + + #[test] + fn test_user_allowed_wildcard() { + let ch = DingTalkChannel::new("id".into(), "secret".into(), vec!["*".into()]); + assert!(ch.is_user_allowed("anyone")); + } + + #[test] + fn test_user_allowed_specific() { + let ch = DingTalkChannel::new("id".into(), "secret".into(), vec!["user123".into()]); + assert!(ch.is_user_allowed("user123")); + assert!(!ch.is_user_allowed("other")); + } + + #[test] + fn test_user_denied_empty() { + let ch = DingTalkChannel::new("id".into(), "secret".into(), vec![]); + assert!(!ch.is_user_allowed("anyone")); + } + + #[test] + fn test_config_serde() { + let toml_str = r#" +client_id = "app_id_123" +client_secret = "secret_456" +allowed_users = ["user1", "*"] +"#; + let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(config.client_id, "app_id_123"); + assert_eq!(config.client_secret, "secret_456"); + assert_eq!(config.allowed_users, vec!["user1", "*"]); + } + + #[test] + fn test_config_serde_defaults() { + let toml_str = r#" +client_id = "id" +client_secret = "secret" +"#; + let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap(); + assert!(config.allowed_users.is_empty()); + } + + #[test] + fn parse_stream_data_supports_string_payload() { + let frame = serde_json::json!({ + "data": "{\"text\":{\"content\":\"hello\"}}" + }); + let parsed = DingTalkChannel::parse_stream_data(&frame).unwrap(); + assert_eq!( + parsed.get("text").and_then(|v| v.get("content")), + Some(&serde_json::json!("hello")) + ); + } + + #[test] + fn parse_stream_data_supports_object_payload() { + let frame = serde_json::json!({ + "data": {"text": {"content": "hello"}} + }); + let parsed = DingTalkChannel::parse_stream_data(&frame).unwrap(); + assert_eq!( + parsed.get("text").and_then(|v| v.get("content")), + Some(&serde_json::json!("hello")) + ); + } + + #[test] + fn resolve_chat_id_handles_numeric_group_conversation_type() { + let data = serde_json::json!({ + "conversationType": 2, + "conversationId": "cid-group", + }); + let chat_id = DingTalkChannel::resolve_chat_id(&data, "staff-1"); + assert_eq!(chat_id, "cid-group"); + } +} diff --git a/src/channels/discord.rs b/src/channels/discord.rs index fd5fe37..939d47c 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -1,6 +1,7 @@ -use super::traits::{Channel, ChannelMessage}; +use super::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; +use parking_lot::Mutex; use serde_json::json; use tokio_tungstenite::tungstenite::Message; use uuid::Uuid; @@ -10,16 +11,28 @@ pub struct DiscordChannel { bot_token: String, guild_id: Option, allowed_users: Vec, + listen_to_bots: bool, + mention_only: bool, client: reqwest::Client, + typing_handle: Mutex>>, } impl DiscordChannel { - pub fn new(bot_token: String, guild_id: Option, allowed_users: Vec) -> Self { + pub fn new( + bot_token: String, + guild_id: Option, + allowed_users: Vec, + listen_to_bots: bool, + mention_only: bool, + ) -> Self { Self { bot_token, guild_id, allowed_users, + listen_to_bots, + mention_only, client: reqwest::Client::new(), + typing_handle: Mutex::new(None), } } @@ -39,6 +52,96 @@ impl DiscordChannel { const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +/// Discord's maximum message length for regular messages. +/// +/// Discord rejects longer payloads with `50035 Invalid Form Body`. +const DISCORD_MAX_MESSAGE_LENGTH: usize = 2000; + +/// Split a message into chunks that respect Discord's 2000-character limit. +/// Tries to split at word boundaries when possible. +fn split_message_for_discord(message: &str) -> Vec { + if message.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH { + return vec![message.to_string()]; + } + + let mut chunks = Vec::new(); + let mut remaining = message; + + while !remaining.is_empty() { + // Find the byte offset for the 2000th character boundary. + // If there are fewer than 2000 chars left, we can emit the tail directly. + let hard_split = remaining + .char_indices() + .nth(DISCORD_MAX_MESSAGE_LENGTH) + .map_or(remaining.len(), |(idx, _)| idx); + + let chunk_end = if hard_split == remaining.len() { + hard_split + } else { + // Try to find a good break point (newline, then space) + let search_area = &remaining[..hard_split]; + + // Prefer splitting at newline + if let Some(pos) = search_area.rfind('\n') { + // Don't split if the newline is too close to the end + if search_area[..pos].chars().count() >= DISCORD_MAX_MESSAGE_LENGTH / 2 { + pos + 1 + } else { + // Try space as fallback + search_area.rfind(' ').map_or(hard_split, |space| space + 1) + } + } else if let Some(pos) = search_area.rfind(' ') { + pos + 1 + } else { + // Hard split at the limit + hard_split + } + }; + + chunks.push(remaining[..chunk_end].to_string()); + remaining = &remaining[chunk_end..]; + } + + chunks +} + +fn mention_tags(bot_user_id: &str) -> [String; 2] { + [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] +} + +fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool { + let tags = mention_tags(bot_user_id); + content.contains(&tags[0]) || content.contains(&tags[1]) +} + +fn normalize_incoming_content( + content: &str, + mention_only: bool, + bot_user_id: &str, +) -> Option { + if content.is_empty() { + return None; + } + + if mention_only && !contains_bot_mention(content, bot_user_id) { + return None; + } + + let mut normalized = content.to_string(); + if mention_only { + for tag in mention_tags(bot_user_id) { + normalized = normalized.replace(&tag, " "); + } + } + + let normalized = normalized.trim().to_string(); + if normalized.is_empty() { + return None; + } + + Some(normalized) +} + /// Minimal base64 decode (no extra dep) — only needs to decode the user ID portion #[allow(clippy::cast_possible_truncation)] fn base64_decode(input: &str) -> Option { @@ -83,16 +186,39 @@ impl Channel for DiscordChannel { "discord" } - async fn send(&self, message: &str, channel_id: &str) -> anyhow::Result<()> { - let url = format!("https://discord.com/api/v10/channels/{channel_id}/messages"); - let body = json!({ "content": message }); + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let chunks = split_message_for_discord(&message.content); - self.client - .post(&url) - .header("Authorization", format!("Bot {}", self.bot_token)) - .json(&body) - .send() - .await?; + for (i, chunk) in chunks.iter().enumerate() { + let url = format!( + "https://discord.com/api/v10/channels/{}/messages", + message.recipient + ); + + let body = json!({ "content": chunk }); + + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bot {}", self.bot_token)) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp + .text() + .await + .unwrap_or_else(|e| format!("")); + anyhow::bail!("Discord send message failed ({status}): {err}"); + } + + // Add a small delay between chunks to avoid rate limiting + if i < chunks.len() - 1 { + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } + } Ok(()) } @@ -136,7 +262,7 @@ impl Channel for DiscordChannel { "op": 2, "d": { "token": self.bot_token, - "intents": 33281, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES + "intents": 37377, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES "properties": { "os": "linux", "browser": "zeroclaw", @@ -148,7 +274,12 @@ impl Channel for DiscordChannel { tracing::info!("Discord: connected and identified"); - // Spawn heartbeat task + // Track the last sequence number for heartbeats and resume. + // Only accessed in the select! loop below, so a plain i64 suffices. + let mut sequence: i64 = -1; + + // Spawn heartbeat timer — sends a tick signal, actual heartbeat + // is assembled in the select! loop where `sequence` lives. let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1); let hb_interval = heartbeat_interval; tokio::spawn(async move { @@ -166,7 +297,8 @@ impl Channel for DiscordChannel { loop { tokio::select! { _ = hb_rx.recv() => { - let hb = json!({"op": 1, "d": null}); + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); if write.send(Message::Text(hb.to_string())).await.is_err() { break; } @@ -183,6 +315,36 @@ impl Channel for DiscordChannel { Err(_) => continue, }; + // Track sequence number from all dispatch events + if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) { + sequence = s; + } + + let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0); + + match op { + // Op 1: Server requests an immediate heartbeat + 1 => { + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); + if write.send(Message::Text(hb.to_string())).await.is_err() { + break; + } + continue; + } + // Op 7: Reconnect + 7 => { + tracing::warn!("Discord: received Reconnect (op 7), closing for restart"); + break; + } + // Op 9: Invalid Session + 9 => { + tracing::warn!("Discord: received Invalid Session (op 9), closing for restart"); + break; + } + _ => {} + } + // Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE") let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or(""); if event_type != "MESSAGE_CREATE" { @@ -199,8 +361,8 @@ impl Channel for DiscordChannel { continue; } - // Skip bot messages - if d.get("author").and_then(|a| a.get("bot")).and_then(serde_json::Value::as_bool).unwrap_or(false) { + // Skip bot messages (unless listen_to_bots is enabled) + if !self.listen_to_bots && d.get("author").and_then(|a| a.get("bot")).and_then(serde_json::Value::as_bool).unwrap_or(false) { continue; } @@ -212,24 +374,39 @@ impl Channel for DiscordChannel { // Guild filter if let Some(ref gid) = guild_filter { - let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or(""); - if msg_guild != gid { - continue; + let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str); + // DMs have no guild_id — let them through; for guild messages, enforce the filter + if let Some(g) = msg_guild { + if g != gid { + continue; + } } } let content = d.get("content").and_then(|c| c.as_str()).unwrap_or(""); - if content.is_empty() { + let Some(clean_content) = + normalize_incoming_content(content, self.mention_only, &bot_user_id) + else { continue; - } + }; + let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); let channel_msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: channel_id, - content: content.to_string(), - channel: "discord".to_string(), + id: if message_id.is_empty() { + Uuid::new_v4().to_string() + } else { + format!("discord_{message_id}") + }, + sender: author_id.to_string(), + reply_target: if channel_id.is_empty() { + author_id.to_string() + } else { + channel_id.clone() + }, + content: clean_content, + channel: channel_id, timestamp: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() @@ -255,6 +432,39 @@ impl Channel for DiscordChannel { .map(|r| r.status().is_success()) .unwrap_or(false) } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + self.stop_typing(recipient).await?; + + let client = self.client.clone(); + let token = self.bot_token.clone(); + let channel_id = recipient.to_string(); + + let handle = tokio::spawn(async move { + let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing"); + loop { + let _ = client + .post(&url) + .header("Authorization", format!("Bot {token}")) + .send() + .await; + tokio::time::sleep(std::time::Duration::from_secs(8)).await; + } + }); + + let mut guard = self.typing_handle.lock(); + *guard = Some(handle); + + Ok(()) + } + + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + let mut guard = self.typing_handle.lock(); + if let Some(handle) = guard.take() { + handle.abort(); + } + Ok(()) + } } #[cfg(test)] @@ -263,7 +473,7 @@ mod tests { #[test] fn discord_channel_name() { - let ch = DiscordChannel::new("fake".into(), None, vec![]); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); assert_eq!(ch.name(), "discord"); } @@ -284,21 +494,27 @@ mod tests { #[test] fn empty_allowlist_denies_everyone() { - let ch = DiscordChannel::new("fake".into(), None, vec![]); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); assert!(!ch.is_user_allowed("12345")); assert!(!ch.is_user_allowed("anyone")); } #[test] fn wildcard_allows_everyone() { - let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()]); + let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false); assert!(ch.is_user_allowed("12345")); assert!(ch.is_user_allowed("anyone")); } #[test] fn specific_allowlist_filters() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()]); + let ch = DiscordChannel::new( + "fake".into(), + None, + vec!["111".into(), "222".into()], + false, + false, + ); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("222")); assert!(!ch.is_user_allowed("333")); @@ -307,7 +523,7 @@ mod tests { #[test] fn allowlist_is_exact_match_not_substring() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()]); + let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false); assert!(!ch.is_user_allowed("1111")); assert!(!ch.is_user_allowed("11")); assert!(!ch.is_user_allowed("0111")); @@ -315,20 +531,26 @@ mod tests { #[test] fn allowlist_empty_string_user_id() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()]); + let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false); assert!(!ch.is_user_allowed("")); } #[test] fn allowlist_with_wildcard_and_specific() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()]); + let ch = DiscordChannel::new( + "fake".into(), + None, + vec!["111".into(), "*".into()], + false, + false, + ); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("anyone_else")); } #[test] fn allowlist_case_sensitive() { - let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()]); + let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false); assert!(ch.is_user_allowed("ABC")); assert!(!ch.is_user_allowed("abc")); assert!(!ch.is_user_allowed("Abc")); @@ -351,4 +573,269 @@ mod tests { let id = DiscordChannel::bot_user_id_from_token(""); assert_eq!(id, Some(String::new())); } + + #[test] + fn contains_bot_mention_supports_plain_and_nick_forms() { + assert!(contains_bot_mention("hi <@12345>", "12345")); + assert!(contains_bot_mention("hi <@!12345>", "12345")); + assert!(!contains_bot_mention("hi <@99999>", "12345")); + } + + #[test] + fn normalize_incoming_content_requires_mention_when_enabled() { + let cleaned = normalize_incoming_content("hello there", true, "12345"); + assert!(cleaned.is_none()); + } + + #[test] + fn normalize_incoming_content_strips_mentions_and_trims() { + let cleaned = normalize_incoming_content(" <@!12345> run status ", true, "12345"); + assert_eq!(cleaned.as_deref(), Some("run status")); + } + + #[test] + fn normalize_incoming_content_rejects_empty_after_strip() { + let cleaned = normalize_incoming_content("<@12345>", true, "12345"); + assert!(cleaned.is_none()); + } + + // Message splitting tests + + #[test] + fn split_empty_message() { + let chunks = split_message_for_discord(""); + assert_eq!(chunks, vec![""]); + } + + #[test] + fn split_short_message_under_limit() { + let msg = "Hello, world!"; + let chunks = split_message_for_discord(msg); + assert_eq!(chunks, vec![msg]); + } + + #[test] + fn split_message_exactly_2000_chars() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH); + let chunks = split_message_for_discord(&msg); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); + } + + #[test] + fn split_message_just_over_limit() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1); + let chunks = split_message_for_discord(&msg); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); + assert_eq!(chunks[1].chars().count(), 1); + } + + #[test] + fn split_very_long_message() { + let msg = "word ".repeat(2000); // 10000 characters (5 chars per "word ") + let chunks = split_message_for_discord(&msg); + // Should split into 5 chunks of <= 2000 chars + assert_eq!(chunks.len(), 5); + assert!(chunks + .iter() + .all(|chunk| chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH)); + // Verify total content is preserved + let reconstructed = chunks.concat(); + assert_eq!(reconstructed, msg); + } + + #[test] + fn split_prefer_newline_break() { + let msg = format!("{}\n{}", "a".repeat(1500), "b".repeat(500)); + let chunks = split_message_for_discord(&msg); + // Should split at the newline + assert_eq!(chunks.len(), 2); + assert!(chunks[0].ends_with('\n')); + assert!(chunks[1].starts_with('b')); + } + + #[test] + fn split_prefer_space_break() { + let msg = format!("{} {}", "a".repeat(1500), "b".repeat(600)); + let chunks = split_message_for_discord(&msg); + assert_eq!(chunks.len(), 2); + } + + #[test] + fn split_without_good_break_points_hard_split() { + // No spaces or newlines - should hard split at 2000 + let msg = "a".repeat(5000); + let chunks = split_message_for_discord(&msg); + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); + assert_eq!(chunks[1].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); + assert_eq!(chunks[2].chars().count(), 1000); + } + + #[test] + fn split_multiple_breaks() { + // Create a message with multiple newlines + let part1 = "a".repeat(900); + let part2 = "b".repeat(900); + let part3 = "c".repeat(900); + let msg = format!("{part1}\n{part2}\n{part3}"); + let chunks = split_message_for_discord(&msg); + // Should split into 2 chunks (first two parts + third part) + assert_eq!(chunks.len(), 2); + assert!(chunks[0].chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); + assert!(chunks[1].chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); + } + + #[test] + fn split_preserves_content() { + let original = "Hello world! This is a test message with some content. ".repeat(200); + let chunks = split_message_for_discord(&original); + let reconstructed = chunks.concat(); + assert_eq!(reconstructed, original); + } + + #[test] + fn split_unicode_content() { + // Test with emoji and multi-byte characters + let msg = "🦀 Rust is awesome! ".repeat(500); + let chunks = split_message_for_discord(&msg); + // All chunks should be valid UTF-8 + for chunk in &chunks { + assert!(std::str::from_utf8(chunk.as_bytes()).is_ok()); + assert!(chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); + } + // Reconstruct and verify + let reconstructed = chunks.concat(); + assert_eq!(reconstructed, msg); + } + + #[test] + fn split_newline_too_close_to_end() { + // If newline is in the first half, don't use it - use space instead or hard split + let msg = format!("{}\n{}", "a".repeat(1900), "b".repeat(500)); + let chunks = split_message_for_discord(&msg); + // Should split at newline since it's in the second half of the window + assert_eq!(chunks.len(), 2); + } + + #[test] + fn split_multibyte_only_content_without_panics() { + let msg = "🦀".repeat(2500); + let chunks = split_message_for_discord(&msg); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); + assert_eq!(chunks[1].chars().count(), 500); + let reconstructed = chunks.concat(); + assert_eq!(reconstructed, msg); + } + + #[test] + fn split_chunks_always_within_discord_limit() { + let msg = "x".repeat(12_345); + let chunks = split_message_for_discord(&msg); + assert!(chunks + .iter() + .all(|chunk| chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH)); + } + + #[test] + fn split_message_with_multiple_newlines() { + let msg = "Line 1\nLine 2\nLine 3\n".repeat(1000); + let chunks = split_message_for_discord(&msg); + assert!(chunks.len() > 1); + let reconstructed = chunks.concat(); + assert_eq!(reconstructed, msg); + } + + #[test] + fn typing_handle_starts_as_none() { + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); + let guard = ch.typing_handle.lock(); + assert!(guard.is_none()); + } + + #[tokio::test] + async fn start_typing_sets_handle() { + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); + let _ = ch.start_typing("123456").await; + let guard = ch.typing_handle.lock(); + assert!(guard.is_some()); + } + + #[tokio::test] + async fn stop_typing_clears_handle() { + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); + let _ = ch.start_typing("123456").await; + let _ = ch.stop_typing("123456").await; + let guard = ch.typing_handle.lock(); + assert!(guard.is_none()); + } + + #[tokio::test] + async fn stop_typing_is_idempotent() { + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); + assert!(ch.stop_typing("123456").await.is_ok()); + assert!(ch.stop_typing("123456").await.is_ok()); + } + + #[tokio::test] + async fn start_typing_replaces_existing_task() { + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); + let _ = ch.start_typing("111").await; + let _ = ch.start_typing("222").await; + let guard = ch.typing_handle.lock(); + assert!(guard.is_some()); + } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn discord_message_id_format_includes_discord_prefix() { + // Verify that message IDs follow the format: discord_{message_id} + let message_id = "123456789012345678"; + let expected_id = format!("discord_{message_id}"); + assert_eq!(expected_id, "discord_123456789012345678"); + } + + #[test] + fn discord_message_id_is_deterministic() { + // Same message_id = same ID (prevents duplicates after restart) + let message_id = "123456789012345678"; + let id1 = format!("discord_{message_id}"); + let id2 = format!("discord_{message_id}"); + assert_eq!(id1, id2); + } + + #[test] + fn discord_message_id_different_message_different_id() { + // Different message IDs produce different IDs + let id1 = "discord_123456789012345678".to_string(); + let id2 = "discord_987654321098765432".to_string(); + assert_ne!(id1, id2); + } + + #[test] + fn discord_message_id_uses_snowflake_id() { + // Discord snowflake IDs are numeric strings + let message_id = "123456789012345678"; // Typical snowflake format + let id = format!("discord_{message_id}"); + assert!(id.starts_with("discord_")); + // Snowflake IDs are numeric + assert!(message_id.chars().all(|c| c.is_ascii_digit())); + } + + #[test] + fn discord_message_id_fallback_to_uuid_on_empty() { + // Edge case: empty message_id falls back to UUID + let message_id = ""; + let id = if message_id.is_empty() { + format!("discord_{}", uuid::Uuid::new_v4()) + } else { + format!("discord_{message_id}") + }; + assert!(id.starts_with("discord_")); + // Should have UUID dashes + assert!(id.contains('-')); + } } diff --git a/src/channels/email_channel.rs b/src/channels/email_channel.rs new file mode 100644 index 0000000..8d06370 --- /dev/null +++ b/src/channels/email_channel.rs @@ -0,0 +1,844 @@ +#![allow(clippy::uninlined_format_args)] +#![allow(clippy::map_unwrap_or)] +#![allow(clippy::redundant_closure_for_method_calls)] +#![allow(clippy::cast_lossless)] +#![allow(clippy::trim_split_whitespace)] +#![allow(clippy::doc_link_with_quotes)] +#![allow(clippy::doc_markdown)] +#![allow(clippy::too_many_lines)] +#![allow(clippy::unnecessary_map_or)] + +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use lettre::message::SinglePart; +use lettre::transport::smtp::authentication::Credentials; +use lettre::{Message, SmtpTransport, Transport}; +use mail_parser::{MessageParser, MimeHeaders}; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use std::io::Write as IoWrite; +use std::net::TcpStream; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::mpsc; +use tokio::time::{interval, sleep}; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use super::traits::{Channel, ChannelMessage, SendMessage}; + +/// Email channel configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmailConfig { + /// IMAP server hostname + pub imap_host: String, + /// IMAP server port (default: 993 for TLS) + #[serde(default = "default_imap_port")] + pub imap_port: u16, + /// IMAP folder to poll (default: INBOX) + #[serde(default = "default_imap_folder")] + pub imap_folder: String, + /// SMTP server hostname + pub smtp_host: String, + /// SMTP server port (default: 465 for TLS) + #[serde(default = "default_smtp_port")] + pub smtp_port: u16, + /// Use TLS for SMTP (default: true) + #[serde(default = "default_true")] + pub smtp_tls: bool, + /// Email username for authentication + pub username: String, + /// Email password for authentication + pub password: String, + /// From address for outgoing emails + pub from_address: String, + /// Poll interval in seconds (default: 60) + #[serde(default = "default_poll_interval")] + pub poll_interval_secs: u64, + /// Allowed sender addresses/domains (empty = deny all, ["*"] = allow all) + #[serde(default)] + pub allowed_senders: Vec, +} + +fn default_imap_port() -> u16 { + 993 +} +fn default_smtp_port() -> u16 { + 465 +} +fn default_imap_folder() -> String { + "INBOX".into() +} +fn default_poll_interval() -> u64 { + 60 +} +fn default_true() -> bool { + true +} + +impl Default for EmailConfig { + fn default() -> Self { + Self { + imap_host: String::new(), + imap_port: default_imap_port(), + imap_folder: default_imap_folder(), + smtp_host: String::new(), + smtp_port: default_smtp_port(), + smtp_tls: true, + username: String::new(), + password: String::new(), + from_address: String::new(), + poll_interval_secs: default_poll_interval(), + allowed_senders: Vec::new(), + } + } +} + +/// Email channel — IMAP polling for inbound, SMTP for outbound +pub struct EmailChannel { + pub config: EmailConfig, + seen_messages: Mutex>, +} + +impl EmailChannel { + pub fn new(config: EmailConfig) -> Self { + Self { + config, + seen_messages: Mutex::new(HashSet::new()), + } + } + + /// Check if a sender email is in the allowlist + pub fn is_sender_allowed(&self, email: &str) -> bool { + if self.config.allowed_senders.is_empty() { + return false; // Empty = deny all + } + if self.config.allowed_senders.iter().any(|a| a == "*") { + return true; // Wildcard = allow all + } + let email_lower = email.to_lowercase(); + self.config.allowed_senders.iter().any(|allowed| { + if allowed.starts_with('@') { + // Domain match with @ prefix: "@example.com" + email_lower.ends_with(&allowed.to_lowercase()) + } else if allowed.contains('@') { + // Full email address match + allowed.eq_ignore_ascii_case(email) + } else { + // Domain match without @ prefix: "example.com" + email_lower.ends_with(&format!("@{}", allowed.to_lowercase())) + } + }) + } + + /// Strip HTML tags from content (basic) + pub fn strip_html(html: &str) -> String { + let mut result = String::new(); + let mut in_tag = false; + for ch in html.chars() { + match ch { + '<' => in_tag = true, + '>' => in_tag = false, + _ if !in_tag => result.push(ch), + _ => {} + } + } + result.split_whitespace().collect::>().join(" ") + } + + /// Extract the sender address from a parsed email + fn extract_sender(parsed: &mail_parser::Message) -> String { + parsed + .from() + .and_then(|addr| addr.first()) + .and_then(|a| a.address()) + .map(|s| s.to_string()) + .unwrap_or_else(|| "unknown".into()) + } + + /// Extract readable text from a parsed email + fn extract_text(parsed: &mail_parser::Message) -> String { + if let Some(text) = parsed.body_text(0) { + return text.to_string(); + } + if let Some(html) = parsed.body_html(0) { + return Self::strip_html(html.as_ref()); + } + for part in parsed.attachments() { + let part: &mail_parser::MessagePart = part; + if let Some(ct) = MimeHeaders::content_type(part) { + if ct.ctype() == "text" { + if let Ok(text) = std::str::from_utf8(part.contents()) { + let name = MimeHeaders::attachment_name(part).unwrap_or("file"); + return format!("[Attachment: {}]\n{}", name, text); + } + } + } + } + "(no readable content)".to_string() + } + + fn build_imap_tls_config() -> Result> { + use rustls::ClientConfig as TlsConfig; + use std::sync::Arc; + use tokio_rustls::rustls; + + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + let crypto_provider = rustls::crypto::CryptoProvider::get_default() + .cloned() + .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())); + + let tls_config = TlsConfig::builder_with_provider(crypto_provider) + .with_protocol_versions(rustls::DEFAULT_VERSIONS)? + .with_root_certificates(root_store) + .with_no_client_auth(); + + Ok(Arc::new(tls_config)) + } + + /// Fetch unseen emails via IMAP (blocking, run in spawn_blocking) + fn fetch_unseen_imap(config: &EmailConfig) -> Result> { + use rustls_pki_types::ServerName; + use tokio_rustls::rustls; + + // Connect TCP + let tcp = TcpStream::connect((&*config.imap_host, config.imap_port))?; + tcp.set_read_timeout(Some(Duration::from_secs(30)))?; + + // TLS + let tls_config = Self::build_imap_tls_config()?; + let server_name: ServerName<'_> = ServerName::try_from(config.imap_host.clone())?; + let conn = rustls::ClientConnection::new(tls_config, server_name)?; + let mut tls = rustls::StreamOwned::new(conn, tcp); + + let read_line = + |tls: &mut rustls::StreamOwned| -> Result { + let mut buf = Vec::new(); + loop { + let mut byte = [0u8; 1]; + match std::io::Read::read(tls, &mut byte) { + Ok(0) => return Err(anyhow!("IMAP connection closed")), + Ok(_) => { + buf.push(byte[0]); + if buf.ends_with(b"\r\n") { + return Ok(String::from_utf8_lossy(&buf).to_string()); + } + } + Err(e) => return Err(e.into()), + } + } + }; + + let send_cmd = |tls: &mut rustls::StreamOwned, + tag: &str, + cmd: &str| + -> Result> { + let full = format!("{} {}\r\n", tag, cmd); + IoWrite::write_all(tls, full.as_bytes())?; + IoWrite::flush(tls)?; + let mut lines = Vec::new(); + loop { + let line = read_line(tls)?; + let done = line.starts_with(tag); + lines.push(line); + if done { + break; + } + } + Ok(lines) + }; + + // Read greeting + let _greeting = read_line(&mut tls)?; + + // Login + let login_resp = send_cmd( + &mut tls, + "A1", + &format!("LOGIN \"{}\" \"{}\"", config.username, config.password), + )?; + if !login_resp.last().map_or(false, |l| l.contains("OK")) { + return Err(anyhow!("IMAP login failed")); + } + + // Select folder + let _select = send_cmd( + &mut tls, + "A2", + &format!("SELECT \"{}\"", config.imap_folder), + )?; + + // Search unseen + let search_resp = send_cmd(&mut tls, "A3", "SEARCH UNSEEN")?; + let mut uids: Vec<&str> = Vec::new(); + for line in &search_resp { + if line.starts_with("* SEARCH") { + let parts: Vec<&str> = line.trim().split_whitespace().collect(); + if parts.len() > 2 { + uids.extend_from_slice(&parts[2..]); + } + } + } + + let mut results = Vec::new(); + let mut tag_counter = 4_u32; // Start after A1, A2, A3 + + for uid in &uids { + // Fetch RFC822 with unique tag + let fetch_tag = format!("A{}", tag_counter); + tag_counter += 1; + let fetch_resp = send_cmd(&mut tls, &fetch_tag, &format!("FETCH {} RFC822", uid))?; + // Reconstruct the raw email from the response (skip first and last lines) + let raw: String = fetch_resp + .iter() + .skip(1) + .take(fetch_resp.len().saturating_sub(2)) + .cloned() + .collect(); + + if let Some(parsed) = MessageParser::default().parse(raw.as_bytes()) { + let sender = Self::extract_sender(&parsed); + let subject = parsed.subject().unwrap_or("(no subject)").to_string(); + let body = Self::extract_text(&parsed); + let content = format!("Subject: {}\n\n{}", subject, body); + let msg_id = parsed + .message_id() + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); + #[allow(clippy::cast_sign_loss)] + let ts = parsed + .date() + .map(|d| { + let naive = chrono::NaiveDate::from_ymd_opt( + d.year as i32, + u32::from(d.month), + u32::from(d.day), + ) + .and_then(|date| { + date.and_hms_opt( + u32::from(d.hour), + u32::from(d.minute), + u32::from(d.second), + ) + }); + naive.map_or(0, |n| n.and_utc().timestamp() as u64) + }) + .unwrap_or_else(|| { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) + }); + + results.push((msg_id, sender, content, ts)); + } + + // Mark as seen with unique tag + let store_tag = format!("A{tag_counter}"); + tag_counter += 1; + let _ = send_cmd( + &mut tls, + &store_tag, + &format!("STORE {uid} +FLAGS (\\Seen)"), + ); + } + + // Logout with unique tag + let logout_tag = format!("A{tag_counter}"); + let _ = send_cmd(&mut tls, &logout_tag, "LOGOUT"); + + Ok(results) + } + + fn create_smtp_transport(&self) -> Result { + let creds = Credentials::new(self.config.username.clone(), self.config.password.clone()); + let transport = if self.config.smtp_tls { + SmtpTransport::relay(&self.config.smtp_host)? + .port(self.config.smtp_port) + .credentials(creds) + .build() + } else { + SmtpTransport::builder_dangerous(&self.config.smtp_host) + .port(self.config.smtp_port) + .credentials(creds) + .build() + }; + Ok(transport) + } +} + +#[async_trait] +impl Channel for EmailChannel { + fn name(&self) -> &str { + "email" + } + + async fn send(&self, message: &SendMessage) -> Result<()> { + // Use explicit subject if provided, otherwise fall back to legacy parsing or default + let (subject, body) = if let Some(ref subj) = message.subject { + (subj.as_str(), message.content.as_str()) + } else if message.content.starts_with("Subject: ") { + if let Some(pos) = message.content.find('\n') { + (&message.content[9..pos], message.content[pos + 1..].trim()) + } else { + ("ZeroClaw Message", message.content.as_str()) + } + } else { + ("ZeroClaw Message", message.content.as_str()) + }; + + let email = Message::builder() + .from(self.config.from_address.parse()?) + .to(message.recipient.parse()?) + .subject(subject) + .singlepart(SinglePart::plain(body.to_string()))?; + + let transport = self.create_smtp_transport()?; + transport.send(&email)?; + info!("Email sent to {}", message.recipient); + Ok(()) + } + + async fn listen(&self, tx: mpsc::Sender) -> Result<()> { + info!( + "Email polling every {}s on {}", + self.config.poll_interval_secs, self.config.imap_folder + ); + let mut tick = interval(Duration::from_secs(self.config.poll_interval_secs)); + let config = self.config.clone(); + + loop { + tick.tick().await; + let cfg = config.clone(); + match tokio::task::spawn_blocking(move || Self::fetch_unseen_imap(&cfg)).await { + Ok(Ok(messages)) => { + for (id, sender, content, ts) in messages { + { + let mut seen = self.seen_messages.lock(); + if seen.contains(&id) { + continue; + } + if !self.is_sender_allowed(&sender) { + warn!("Blocked email from {}", sender); + continue; + } + seen.insert(id.clone()); + } // MutexGuard dropped before await + let msg = ChannelMessage { + id, + reply_target: sender.clone(), + sender, + content, + channel: "email".to_string(), + timestamp: ts, + }; + if tx.send(msg).await.is_err() { + return Ok(()); + } + } + } + Ok(Err(e)) => { + error!("Email poll failed: {}", e); + sleep(Duration::from_secs(10)).await; + } + Err(e) => { + error!("Email poll task panicked: {}", e); + sleep(Duration::from_secs(10)).await; + } + } + } + } + + async fn health_check(&self) -> bool { + let cfg = self.config.clone(); + tokio::task::spawn_blocking(move || { + let tcp = TcpStream::connect((&*cfg.imap_host, cfg.imap_port)); + tcp.is_ok() + }) + .await + .unwrap_or_default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_smtp_port_uses_tls_port() { + assert_eq!(default_smtp_port(), 465); + } + + #[test] + fn email_config_default_uses_tls_smtp_defaults() { + let config = EmailConfig::default(); + assert_eq!(config.smtp_port, 465); + assert!(config.smtp_tls); + } + + #[test] + fn build_imap_tls_config_succeeds() { + let tls_config = + EmailChannel::build_imap_tls_config().expect("TLS config construction should succeed"); + assert_eq!(std::sync::Arc::strong_count(&tls_config), 1); + } + + #[test] + fn seen_messages_starts_empty() { + let channel = EmailChannel::new(EmailConfig::default()); + let seen = channel.seen_messages.lock(); + assert!(seen.is_empty()); + } + + #[test] + fn seen_messages_tracks_unique_ids() { + let channel = EmailChannel::new(EmailConfig::default()); + let mut seen = channel.seen_messages.lock(); + + assert!(seen.insert("first-id".to_string())); + assert!(!seen.insert("first-id".to_string())); + assert!(seen.insert("second-id".to_string())); + assert_eq!(seen.len(), 2); + } + + // EmailConfig tests + + #[test] + fn email_config_default() { + let config = EmailConfig::default(); + assert_eq!(config.imap_host, ""); + assert_eq!(config.imap_port, 993); + assert_eq!(config.imap_folder, "INBOX"); + assert_eq!(config.smtp_host, ""); + assert_eq!(config.smtp_port, 465); + assert!(config.smtp_tls); + assert_eq!(config.username, ""); + assert_eq!(config.password, ""); + assert_eq!(config.from_address, ""); + assert_eq!(config.poll_interval_secs, 60); + assert!(config.allowed_senders.is_empty()); + } + + #[test] + fn email_config_custom() { + let config = EmailConfig { + imap_host: "imap.example.com".to_string(), + imap_port: 993, + imap_folder: "Archive".to_string(), + smtp_host: "smtp.example.com".to_string(), + smtp_port: 465, + smtp_tls: true, + username: "user@example.com".to_string(), + password: "pass123".to_string(), + from_address: "bot@example.com".to_string(), + poll_interval_secs: 30, + allowed_senders: vec!["allowed@example.com".to_string()], + }; + assert_eq!(config.imap_host, "imap.example.com"); + assert_eq!(config.imap_folder, "Archive"); + assert_eq!(config.poll_interval_secs, 30); + } + + #[test] + fn email_config_clone() { + let config = EmailConfig { + imap_host: "imap.test.com".to_string(), + imap_port: 993, + imap_folder: "INBOX".to_string(), + smtp_host: "smtp.test.com".to_string(), + smtp_port: 587, + smtp_tls: true, + username: "user@test.com".to_string(), + password: "secret".to_string(), + from_address: "bot@test.com".to_string(), + poll_interval_secs: 120, + allowed_senders: vec!["*".to_string()], + }; + let cloned = config.clone(); + assert_eq!(cloned.imap_host, config.imap_host); + assert_eq!(cloned.smtp_port, config.smtp_port); + assert_eq!(cloned.allowed_senders, config.allowed_senders); + } + + // EmailChannel tests + + #[test] + fn email_channel_new() { + let config = EmailConfig::default(); + let channel = EmailChannel::new(config.clone()); + assert_eq!(channel.config.imap_host, config.imap_host); + + let seen_guard = channel.seen_messages.lock(); + assert_eq!(seen_guard.len(), 0); + } + + #[test] + fn email_channel_name() { + let channel = EmailChannel::new(EmailConfig::default()); + assert_eq!(channel.name(), "email"); + } + + // is_sender_allowed tests + + #[test] + fn is_sender_allowed_empty_list_denies_all() { + let config = EmailConfig { + allowed_senders: vec![], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(!channel.is_sender_allowed("anyone@example.com")); + assert!(!channel.is_sender_allowed("user@test.com")); + } + + #[test] + fn is_sender_allowed_wildcard_allows_all() { + let config = EmailConfig { + allowed_senders: vec!["*".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("anyone@example.com")); + assert!(channel.is_sender_allowed("user@test.com")); + assert!(channel.is_sender_allowed("random@domain.org")); + } + + #[test] + fn is_sender_allowed_specific_email() { + let config = EmailConfig { + allowed_senders: vec!["allowed@example.com".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("allowed@example.com")); + assert!(!channel.is_sender_allowed("other@example.com")); + assert!(!channel.is_sender_allowed("allowed@other.com")); + } + + #[test] + fn is_sender_allowed_domain_with_at_prefix() { + let config = EmailConfig { + allowed_senders: vec!["@example.com".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("user@example.com")); + assert!(channel.is_sender_allowed("admin@example.com")); + assert!(!channel.is_sender_allowed("user@other.com")); + } + + #[test] + fn is_sender_allowed_domain_without_at_prefix() { + let config = EmailConfig { + allowed_senders: vec!["example.com".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("user@example.com")); + assert!(channel.is_sender_allowed("admin@example.com")); + assert!(!channel.is_sender_allowed("user@other.com")); + } + + #[test] + fn is_sender_allowed_case_insensitive() { + let config = EmailConfig { + allowed_senders: vec!["Allowed@Example.COM".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("allowed@example.com")); + assert!(channel.is_sender_allowed("ALLOWED@EXAMPLE.COM")); + assert!(channel.is_sender_allowed("AlLoWeD@eXaMpLe.cOm")); + } + + #[test] + fn is_sender_allowed_multiple_senders() { + let config = EmailConfig { + allowed_senders: vec![ + "user1@example.com".to_string(), + "user2@test.com".to_string(), + "@allowed.com".to_string(), + ], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("user1@example.com")); + assert!(channel.is_sender_allowed("user2@test.com")); + assert!(channel.is_sender_allowed("anyone@allowed.com")); + assert!(!channel.is_sender_allowed("user3@example.com")); + } + + #[test] + fn is_sender_allowed_wildcard_with_specific() { + let config = EmailConfig { + allowed_senders: vec!["*".to_string(), "specific@example.com".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(channel.is_sender_allowed("anyone@example.com")); + assert!(channel.is_sender_allowed("specific@example.com")); + } + + #[test] + fn is_sender_allowed_empty_sender() { + let config = EmailConfig { + allowed_senders: vec!["@example.com".to_string()], + ..Default::default() + }; + let channel = EmailChannel::new(config); + assert!(!channel.is_sender_allowed("")); + // "@example.com" ends with "@example.com" so it's allowed + assert!(channel.is_sender_allowed("@example.com")); + } + + // strip_html tests + + #[test] + fn strip_html_basic() { + assert_eq!(EmailChannel::strip_html("

Hello

"), "Hello"); + assert_eq!(EmailChannel::strip_html("
World
"), "World"); + } + + #[test] + fn strip_html_nested_tags() { + assert_eq!( + EmailChannel::strip_html("

Hello World

"), + "Hello World" + ); + } + + #[test] + fn strip_html_multiple_lines() { + let html = "
\n

Line 1

\n

Line 2

\n
"; + assert_eq!(EmailChannel::strip_html(html), "Line 1 Line 2"); + } + + #[test] + fn strip_html_preserves_text() { + assert_eq!(EmailChannel::strip_html("No tags here"), "No tags here"); + assert_eq!(EmailChannel::strip_html(""), ""); + } + + #[test] + fn strip_html_handles_malformed() { + assert_eq!(EmailChannel::strip_html("

Unclosed"), "Unclosed"); + // The function removes everything between < and >, so "Text>with>brackets" becomes "Textwithbrackets" + assert_eq!( + EmailChannel::strip_html("Text>with>brackets"), + "Textwithbrackets" + ); + } + + #[test] + fn strip_html_self_closing_tags() { + // Self-closing tags are removed but don't add spaces + assert_eq!(EmailChannel::strip_html("Hello
World"), "HelloWorld"); + assert_eq!(EmailChannel::strip_html("Text


More"), "TextMore"); + } + + #[test] + fn strip_html_attributes_preserved() { + assert_eq!( + EmailChannel::strip_html("
Link"), + "Link" + ); + } + + #[test] + fn strip_html_multiple_spaces_collapsed() { + assert_eq!( + EmailChannel::strip_html("

Word

Word

"), + "Word Word" + ); + } + + #[test] + fn strip_html_special_characters() { + assert_eq!( + EmailChannel::strip_html("<tag>"), + "<tag>" + ); + } + + // Default function tests + + #[test] + fn default_imap_port_returns_993() { + assert_eq!(default_imap_port(), 993); + } + + #[test] + fn default_smtp_port_returns_465() { + assert_eq!(default_smtp_port(), 465); + } + + #[test] + fn default_imap_folder_returns_inbox() { + assert_eq!(default_imap_folder(), "INBOX"); + } + + #[test] + fn default_poll_interval_returns_60() { + assert_eq!(default_poll_interval(), 60); + } + + #[test] + fn default_true_returns_true() { + assert!(default_true()); + } + + // EmailConfig serialization tests + + #[test] + fn email_config_serialize_deserialize() { + let config = EmailConfig { + imap_host: "imap.example.com".to_string(), + imap_port: 993, + imap_folder: "INBOX".to_string(), + smtp_host: "smtp.example.com".to_string(), + smtp_port: 587, + smtp_tls: true, + username: "user@example.com".to_string(), + password: "password123".to_string(), + from_address: "bot@example.com".to_string(), + poll_interval_secs: 30, + allowed_senders: vec!["allowed@example.com".to_string()], + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: EmailConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.imap_host, config.imap_host); + assert_eq!(deserialized.smtp_port, config.smtp_port); + assert_eq!(deserialized.allowed_senders, config.allowed_senders); + } + + #[test] + fn email_config_deserialize_with_defaults() { + let json = r#"{ + "imap_host": "imap.test.com", + "smtp_host": "smtp.test.com", + "username": "user", + "password": "pass", + "from_address": "bot@test.com" + }"#; + + let config: EmailConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.imap_port, 993); // default + assert_eq!(config.smtp_port, 465); // default + assert!(config.smtp_tls); // default + assert_eq!(config.poll_interval_secs, 60); // default + } + + #[test] + fn email_config_debug_output() { + let config = EmailConfig { + imap_host: "imap.debug.com".to_string(), + ..Default::default() + }; + let debug_str = format!("{:?}", config); + assert!(debug_str.contains("imap.debug.com")); + } +} diff --git a/src/channels/imessage.rs b/src/channels/imessage.rs index 1272f0c..8dbd614 100644 --- a/src/channels/imessage.rs +++ b/src/channels/imessage.rs @@ -1,4 +1,4 @@ -use crate::channels::traits::{Channel, ChannelMessage}; +use crate::channels::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; use directories::UserDirs; use rusqlite::{Connection, OpenFlags}; @@ -36,8 +36,12 @@ impl IMessageChannel { /// This prevents injection attacks by escaping: /// - Backslashes (`\` → `\\`) /// - Double quotes (`"` → `\"`) +/// - Newlines (`\n` → `\\n`, `\r` → `\\r`) to prevent code injection via line breaks fn escape_applescript(s: &str) -> String { - s.replace('\\', "\\\\").replace('"', "\\\"") + s.replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('\n', "\\n") + .replace('\r', "\\r") } /// Validate that a target looks like a valid phone number or email address. @@ -91,9 +95,9 @@ impl Channel for IMessageChannel { "imessage" } - async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> { + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { // Defense-in-depth: validate target format before any interpolation - if !is_valid_imessage_target(target) { + if !is_valid_imessage_target(&message.recipient) { anyhow::bail!( "Invalid iMessage target: must be a phone number (+1234567890) or email (user@example.com)" ); @@ -101,8 +105,8 @@ impl Channel for IMessageChannel { // SECURITY: Escape both message AND target to prevent AppleScript injection // See: CWE-78 (OS Command Injection) - let escaped_msg = escape_applescript(message); - let escaped_target = escape_applescript(target); + let escaped_msg = escape_applescript(&message.content); + let escaped_target = escape_applescript(&message.recipient); let script = format!( r#"tell application "Messages" @@ -168,6 +172,7 @@ end tell"# let msg = ChannelMessage { id: rowid.to_string(), sender: sender.clone(), + reply_target: sender.clone(), content: text, channel: "imessage".to_string(), timestamp: std::time::SystemTime::now() @@ -386,8 +391,10 @@ mod tests { } #[test] - fn escape_applescript_newlines_preserved() { - assert_eq!(escape_applescript("line1\nline2"), "line1\nline2"); + fn escape_applescript_newlines_escaped() { + assert_eq!(escape_applescript("line1\nline2"), "line1\\nline2"); + assert_eq!(escape_applescript("line1\rline2"), "line1\\rline2"); + assert_eq!(escape_applescript("line1\r\nline2"), "line1\\r\\nline2"); } // ══════════════════════════════════════════════════════════ diff --git a/src/channels/irc.rs b/src/channels/irc.rs new file mode 100644 index 0000000..2e03378 --- /dev/null +++ b/src/channels/irc.rs @@ -0,0 +1,1014 @@ +use crate::channels::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{mpsc, Mutex}; + +// Use tokio_rustls's re-export of rustls types +use tokio_rustls::rustls; + +/// Read timeout for IRC — if no data arrives within this duration, the +/// connection is considered dead. IRC servers typically PING every 60-120s. +const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); + +/// Monotonic counter to ensure unique message IDs under burst traffic. +static MSG_SEQ: AtomicU64 = AtomicU64::new(0); + +/// IRC over TLS channel. +/// +/// Connects to an IRC server using TLS, joins configured channels, +/// and forwards PRIVMSG messages to the `ZeroClaw` message bus. +/// Supports both channel messages and private messages (DMs). +pub struct IrcChannel { + server: String, + port: u16, + nickname: String, + username: String, + channels: Vec, + allowed_users: Vec, + server_password: Option, + nickserv_password: Option, + sasl_password: Option, + verify_tls: bool, + /// Shared write half of the TLS stream for sending messages. + writer: Arc>>, +} + +type WriteHalf = tokio::io::WriteHalf>; + +/// Style instruction prepended to every IRC message before it reaches the LLM. +/// IRC clients render plain text only — no markdown, no HTML, no XML. +const IRC_STYLE_PREFIX: &str = "\ +[context: you are responding over IRC. \ +Plain text only. No markdown, no tables, no XML/HTML tags. \ +Never use triple backtick code fences. Use a single blank line to separate blocks instead. \ +Be terse and concise. \ +Use short lines. Avoid walls of text.]\n"; + +/// Reserved bytes for the server-prepended sender prefix (`:nick!user@host `). +const SENDER_PREFIX_RESERVE: usize = 64; + +/// A parsed IRC message. +#[derive(Debug, Clone, PartialEq, Eq)] +struct IrcMessage { + prefix: Option, + command: String, + params: Vec, +} + +impl IrcMessage { + /// Parse a raw IRC line into an `IrcMessage`. + /// + /// IRC format: `[:] [] [:]` + fn parse(line: &str) -> Option { + let line = line.trim_end_matches(['\r', '\n']); + if line.is_empty() { + return None; + } + + let (prefix, rest) = if let Some(stripped) = line.strip_prefix(':') { + let space = stripped.find(' ')?; + (Some(stripped[..space].to_string()), &stripped[space + 1..]) + } else { + (None, line) + }; + + // Split at trailing (first `:` after command/params) + let (params_part, trailing) = if let Some(colon_pos) = rest.find(" :") { + (&rest[..colon_pos], Some(&rest[colon_pos + 2..])) + } else { + (rest, None) + }; + + let mut parts: Vec<&str> = params_part.split_whitespace().collect(); + if parts.is_empty() { + return None; + } + + let command = parts.remove(0).to_uppercase(); + let mut params: Vec = parts.iter().map(std::string::ToString::to_string).collect(); + if let Some(t) = trailing { + params.push(t.to_string()); + } + + Some(IrcMessage { + prefix, + command, + params, + }) + } + + /// Extract the nickname from the prefix (nick!user@host → nick). + fn nick(&self) -> Option<&str> { + self.prefix.as_ref().and_then(|p| { + let end = p.find('!').unwrap_or(p.len()); + let nick = &p[..end]; + if nick.is_empty() { + None + } else { + Some(nick) + } + }) + } +} + +/// Encode SASL PLAIN credentials: base64(\0nick\0password). +fn encode_sasl_plain(nick: &str, password: &str) -> String { + // Simple base64 encoder — avoids adding a base64 crate dependency. + // The project's Discord channel uses a similar inline approach. + const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + let input = format!("\0{nick}\0{password}"); + let bytes = input.as_bytes(); + let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4); + + for chunk in bytes.chunks(3) { + let b0 = u32::from(chunk[0]); + let b1 = u32::from(chunk.get(1).copied().unwrap_or(0)); + let b2 = u32::from(chunk.get(2).copied().unwrap_or(0)); + let triple = (b0 << 16) | (b1 << 8) | b2; + + out.push(CHARS[(triple >> 18 & 0x3F) as usize] as char); + out.push(CHARS[(triple >> 12 & 0x3F) as usize] as char); + + if chunk.len() > 1 { + out.push(CHARS[(triple >> 6 & 0x3F) as usize] as char); + } else { + out.push('='); + } + + if chunk.len() > 2 { + out.push(CHARS[(triple & 0x3F) as usize] as char); + } else { + out.push('='); + } + } + + out +} + +/// Split a message into lines safe for IRC transmission. +/// +/// IRC is a line-based protocol — `\r\n` terminates each command, so any +/// newline inside a PRIVMSG payload would truncate the message and turn the +/// remainder into garbled/invalid IRC commands. +/// +/// This function: +/// 1. Splits on `\n` (and strips `\r`) so each logical line becomes its own PRIVMSG. +/// 2. Splits any line that exceeds `max_bytes` at a safe UTF-8 boundary. +/// 3. Skips empty lines to avoid sending blank PRIVMSGs. +fn split_message(message: &str, max_bytes: usize) -> Vec { + let mut chunks = Vec::new(); + + // Guard against max_bytes == 0 to prevent infinite loop + if max_bytes == 0 { + let full: String = message + .lines() + .map(|l| l.trim_end_matches('\r')) + .filter(|l| !l.is_empty()) + .collect::>() + .join(" "); + if full.is_empty() { + chunks.push(String::new()); + } else { + chunks.push(full); + } + return chunks; + } + + for line in message.split('\n') { + let line = line.trim_end_matches('\r'); + if line.is_empty() { + continue; + } + + if line.len() <= max_bytes { + chunks.push(line.to_string()); + continue; + } + + // Line exceeds max_bytes — split at safe UTF-8 boundaries + let mut remaining = line; + while !remaining.is_empty() { + if remaining.len() <= max_bytes { + chunks.push(remaining.to_string()); + break; + } + + let mut split_at = max_bytes; + while split_at > 0 && !remaining.is_char_boundary(split_at) { + split_at -= 1; + } + if split_at == 0 { + // No valid boundary found going backward — advance forward instead + split_at = max_bytes; + while split_at < remaining.len() && !remaining.is_char_boundary(split_at) { + split_at += 1; + } + } + + chunks.push(remaining[..split_at].to_string()); + remaining = &remaining[split_at..]; + } + } + + if chunks.is_empty() { + chunks.push(String::new()); + } + + chunks +} + +/// Configuration for constructing an `IrcChannel`. +pub struct IrcChannelConfig { + pub server: String, + pub port: u16, + pub nickname: String, + pub username: Option, + pub channels: Vec, + pub allowed_users: Vec, + pub server_password: Option, + pub nickserv_password: Option, + pub sasl_password: Option, + pub verify_tls: bool, +} + +impl IrcChannel { + pub fn new(cfg: IrcChannelConfig) -> Self { + let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone()); + Self { + server: cfg.server, + port: cfg.port, + nickname: cfg.nickname, + username, + channels: cfg.channels, + allowed_users: cfg.allowed_users, + server_password: cfg.server_password, + nickserv_password: cfg.nickserv_password, + sasl_password: cfg.sasl_password, + verify_tls: cfg.verify_tls, + writer: Arc::new(Mutex::new(None)), + } + } + + fn is_user_allowed(&self, nick: &str) -> bool { + if self.allowed_users.iter().any(|u| u == "*") { + return true; + } + self.allowed_users + .iter() + .any(|u| u.eq_ignore_ascii_case(nick)) + } + + /// Create a TLS connection to the IRC server. + async fn connect( + &self, + ) -> anyhow::Result> { + let addr = format!("{}:{}", self.server, self.port); + let tcp = tokio::net::TcpStream::connect(&addr).await?; + + let tls_config = if self.verify_tls { + let root_store: rustls::RootCertStore = + webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(); + rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth() + } else { + rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(NoVerify)) + .with_no_client_auth() + }; + + let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_config)); + let domain = rustls::pki_types::ServerName::try_from(self.server.clone())?; + let tls = connector.connect(domain, tcp).await?; + + Ok(tls) + } + + /// Send a raw IRC line (appends \r\n). + async fn send_raw(writer: &mut WriteHalf, line: &str) -> anyhow::Result<()> { + let data = format!("{line}\r\n"); + writer.write_all(data.as_bytes()).await?; + writer.flush().await?; + Ok(()) + } +} + +/// Certificate verifier that accepts any certificate (for `verify_tls=false`). +#[derive(Debug)] +struct NoVerify; + +impl rustls::client::danger::ServerCertVerifier for NoVerify { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +#[async_trait] +#[allow(clippy::too_many_lines)] +impl Channel for IrcChannel { + fn name(&self) -> &str { + "irc" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let mut guard = self.writer.lock().await; + let writer = guard + .as_mut() + .ok_or_else(|| anyhow::anyhow!("IRC not connected"))?; + + // Calculate safe payload size: + // 512 - sender prefix (~64 bytes for :nick!user@host) - "PRIVMSG " - target - " :" - "\r\n" + let overhead = SENDER_PREFIX_RESERVE + 10 + message.recipient.len() + 2; + let max_payload = 512_usize.saturating_sub(overhead); + let chunks = split_message(&message.content, max_payload); + + for chunk in chunks { + Self::send_raw(writer, &format!("PRIVMSG {} :{chunk}", message.recipient)).await?; + } + + Ok(()) + } + + async fn listen(&self, tx: mpsc::Sender) -> anyhow::Result<()> { + let mut current_nick = self.nickname.clone(); + tracing::info!( + "IRC channel connecting to {}:{} as {}...", + self.server, + self.port, + current_nick + ); + + let tls = self.connect().await?; + let (reader, mut writer) = tokio::io::split(tls); + + // --- SASL negotiation --- + if self.sasl_password.is_some() { + Self::send_raw(&mut writer, "CAP REQ :sasl").await?; + } + + // --- Server password --- + if let Some(ref pass) = self.server_password { + Self::send_raw(&mut writer, &format!("PASS {pass}")).await?; + } + + // --- Nick/User registration --- + Self::send_raw(&mut writer, &format!("NICK {current_nick}")).await?; + Self::send_raw( + &mut writer, + &format!("USER {} 0 * :ZeroClaw", self.username), + ) + .await?; + + // Store writer for send() + { + let mut guard = self.writer.lock().await; + *guard = Some(writer); + } + + let mut buf_reader = BufReader::new(reader); + let mut line = String::new(); + let mut registered = false; + let mut sasl_pending = self.sasl_password.is_some(); + + loop { + line.clear(); + let n = tokio::time::timeout(READ_TIMEOUT, buf_reader.read_line(&mut line)) + .await + .map_err(|_| { + anyhow::anyhow!("IRC read timed out (no data for {READ_TIMEOUT:?})") + })??; + if n == 0 { + anyhow::bail!("IRC connection closed by server"); + } + + let Some(msg) = IrcMessage::parse(&line) else { + continue; + }; + + match msg.command.as_str() { + "PING" => { + let token = msg.params.first().map_or("", String::as_str); + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, &format!("PONG :{token}")).await?; + } + } + + // CAP responses for SASL + "CAP" => { + if sasl_pending && msg.params.iter().any(|p| p.contains("sasl")) { + if msg.params.iter().any(|p| p.contains("ACK")) { + // CAP * ACK :sasl — server accepted, start SASL auth + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, "AUTHENTICATE PLAIN").await?; + } + } else if msg.params.iter().any(|p| p.contains("NAK")) { + // CAP * NAK :sasl — server rejected SASL, proceed without it + tracing::warn!( + "IRC server does not support SASL, continuing without it" + ); + sasl_pending = false; + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, "CAP END").await?; + } + } + } + } + + "AUTHENTICATE" => { + // Server sends "AUTHENTICATE +" to request credentials + if sasl_pending && msg.params.first().is_some_and(|p| p == "+") { + if let Some(password) = self.sasl_password.as_deref() { + let encoded = encode_sasl_plain(¤t_nick, password); + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, &format!("AUTHENTICATE {encoded}")).await?; + } + } else { + // SASL was requested but no password is configured; abort SASL + tracing::warn!( + "SASL authentication requested but no SASL password is configured; aborting SASL" + ); + sasl_pending = false; + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, "CAP END").await?; + } + } + } + } + + // RPL_SASLSUCCESS (903) — SASL done, end CAP + "903" => { + sasl_pending = false; + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, "CAP END").await?; + } + } + + // SASL failure (904, 905, 906, 907) + "904" | "905" | "906" | "907" => { + tracing::warn!("IRC SASL authentication failed ({})", msg.command); + sasl_pending = false; + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, "CAP END").await?; + } + } + + // RPL_WELCOME — registration complete + "001" => { + registered = true; + tracing::info!("IRC registered as {}", current_nick); + + // NickServ authentication + if let Some(ref pass) = self.nickserv_password { + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, &format!("PRIVMSG NickServ :IDENTIFY {pass}")) + .await?; + } + } + + // Join channels + for chan in &self.channels { + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, &format!("JOIN {chan}")).await?; + } + } + } + + // ERR_NICKNAMEINUSE (433) + "433" => { + let alt = format!("{current_nick}_"); + tracing::warn!("IRC nickname {current_nick} is in use, trying {alt}"); + let mut guard = self.writer.lock().await; + if let Some(ref mut w) = *guard { + Self::send_raw(w, &format!("NICK {alt}")).await?; + } + current_nick = alt; + } + + "PRIVMSG" => { + if !registered { + continue; + } + + let target = msg.params.first().map_or("", String::as_str); + let text = msg.params.get(1).map_or("", String::as_str); + let sender_nick = msg.nick().unwrap_or("unknown"); + + // Skip messages from NickServ/ChanServ + if sender_nick.eq_ignore_ascii_case("NickServ") + || sender_nick.eq_ignore_ascii_case("ChanServ") + { + continue; + } + + if !self.is_user_allowed(sender_nick) { + continue; + } + + // Determine reply target: if sent to a channel, reply to channel; + // if DM (target == our nick), reply to sender + let is_channel = target.starts_with('#') || target.starts_with('&'); + let reply_to = if is_channel { + target.to_string() + } else { + sender_nick.to_string() + }; + let content = if is_channel { + format!("{IRC_STYLE_PREFIX}<{sender_nick}> {text}") + } else { + format!("{IRC_STYLE_PREFIX}{text}") + }; + + let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed); + let channel_msg = ChannelMessage { + id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()), + sender: sender_nick.to_string(), + reply_target: reply_to, + content, + channel: "irc".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + if tx.send(channel_msg).await.is_err() { + return Ok(()); + } + } + + // ERR_PASSWDMISMATCH (464) or other fatal errors + "464" => { + anyhow::bail!("IRC password mismatch"); + } + + _ => {} + } + } + } + + async fn health_check(&self) -> bool { + // Lightweight connectivity check: TLS connect + QUIT + match self.connect().await { + Ok(tls) => { + let (_, mut writer) = tokio::io::split(tls); + let _ = Self::send_raw(&mut writer, "QUIT :health check").await; + true + } + Err(_) => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── IRC message parsing ────────────────────────────────── + + #[test] + fn parse_privmsg_with_prefix() { + let msg = IrcMessage::parse(":nick!user@host PRIVMSG #channel :Hello world").unwrap(); + assert_eq!(msg.prefix.as_deref(), Some("nick!user@host")); + assert_eq!(msg.command, "PRIVMSG"); + assert_eq!(msg.params, vec!["#channel", "Hello world"]); + } + + #[test] + fn parse_privmsg_dm() { + let msg = IrcMessage::parse(":alice!a@host PRIVMSG botname :hi there").unwrap(); + assert_eq!(msg.command, "PRIVMSG"); + assert_eq!(msg.params, vec!["botname", "hi there"]); + assert_eq!(msg.nick(), Some("alice")); + } + + #[test] + fn parse_ping() { + let msg = IrcMessage::parse("PING :server.example.com").unwrap(); + assert!(msg.prefix.is_none()); + assert_eq!(msg.command, "PING"); + assert_eq!(msg.params, vec!["server.example.com"]); + } + + #[test] + fn parse_numeric_reply() { + let msg = IrcMessage::parse(":server 001 botname :Welcome to the IRC network").unwrap(); + assert_eq!(msg.prefix.as_deref(), Some("server")); + assert_eq!(msg.command, "001"); + assert_eq!(msg.params, vec!["botname", "Welcome to the IRC network"]); + } + + #[test] + fn parse_no_trailing() { + let msg = IrcMessage::parse(":server 433 * botname").unwrap(); + assert_eq!(msg.command, "433"); + assert_eq!(msg.params, vec!["*", "botname"]); + } + + #[test] + fn parse_cap_ack() { + let msg = IrcMessage::parse(":server CAP * ACK :sasl").unwrap(); + assert_eq!(msg.command, "CAP"); + assert_eq!(msg.params, vec!["*", "ACK", "sasl"]); + } + + #[test] + fn parse_empty_line_returns_none() { + assert!(IrcMessage::parse("").is_none()); + assert!(IrcMessage::parse("\r\n").is_none()); + } + + #[test] + fn parse_strips_crlf() { + let msg = IrcMessage::parse("PING :test\r\n").unwrap(); + assert_eq!(msg.params, vec!["test"]); + } + + #[test] + fn parse_command_uppercase() { + let msg = IrcMessage::parse("ping :test").unwrap(); + assert_eq!(msg.command, "PING"); + } + + #[test] + fn nick_extraction_full_prefix() { + let msg = IrcMessage::parse(":nick!user@host PRIVMSG #ch :msg").unwrap(); + assert_eq!(msg.nick(), Some("nick")); + } + + #[test] + fn nick_extraction_nick_only() { + let msg = IrcMessage::parse(":server 001 bot :Welcome").unwrap(); + assert_eq!(msg.nick(), Some("server")); + } + + #[test] + fn nick_extraction_no_prefix() { + let msg = IrcMessage::parse("PING :token").unwrap(); + assert_eq!(msg.nick(), None); + } + + #[test] + fn parse_authenticate_plus() { + let msg = IrcMessage::parse("AUTHENTICATE +").unwrap(); + assert_eq!(msg.command, "AUTHENTICATE"); + assert_eq!(msg.params, vec!["+"]); + } + + // ── SASL PLAIN encoding ───────────────────────────────── + + #[test] + fn sasl_plain_encode() { + let encoded = encode_sasl_plain("jilles", "sesame"); + // \0jilles\0sesame → base64 + assert_eq!(encoded, "AGppbGxlcwBzZXNhbWU="); + } + + #[test] + fn sasl_plain_empty_password() { + let encoded = encode_sasl_plain("nick", ""); + // \0nick\0 → base64 + assert_eq!(encoded, "AG5pY2sA"); + } + + // ── Message splitting ─────────────────────────────────── + + #[test] + fn split_short_message() { + let chunks = split_message("hello", 400); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_long_message() { + let msg = "a".repeat(800); + let chunks = split_message(&msg, 400); + assert_eq!(chunks.len(), 2); + assert_eq!(chunks[0].len(), 400); + assert_eq!(chunks[1].len(), 400); + } + + #[test] + fn split_exact_boundary() { + let msg = "a".repeat(400); + let chunks = split_message(&msg, 400); + assert_eq!(chunks.len(), 1); + } + + #[test] + fn split_unicode_safe() { + // 'é' is 2 bytes in UTF-8; splitting at byte 3 would split mid-char + let msg = "ééé"; // 6 bytes + let chunks = split_message(msg, 3); + // Should split at char boundary (2 bytes), not mid-char + assert_eq!(chunks.len(), 3); + assert_eq!(chunks[0], "é"); + assert_eq!(chunks[1], "é"); + assert_eq!(chunks[2], "é"); + } + + #[test] + fn split_empty_message() { + let chunks = split_message("", 400); + assert_eq!(chunks, vec![""]); + } + + #[test] + fn split_newlines_into_separate_lines() { + let chunks = split_message("line one\nline two\nline three", 400); + assert_eq!(chunks, vec!["line one", "line two", "line three"]); + } + + #[test] + fn split_crlf_newlines() { + let chunks = split_message("hello\r\nworld", 400); + assert_eq!(chunks, vec!["hello", "world"]); + } + + #[test] + fn split_skips_empty_lines() { + let chunks = split_message("hello\n\n\nworld", 400); + assert_eq!(chunks, vec!["hello", "world"]); + } + + #[test] + fn split_trailing_newline() { + let chunks = split_message("hello\n", 400); + assert_eq!(chunks, vec!["hello"]); + } + + #[test] + fn split_multiline_with_long_line() { + let long = "a".repeat(800); + let msg = format!("short\n{long}\nend"); + let chunks = split_message(&msg, 400); + assert_eq!(chunks.len(), 4); + assert_eq!(chunks[0], "short"); + assert_eq!(chunks[1].len(), 400); + assert_eq!(chunks[2].len(), 400); + assert_eq!(chunks[3], "end"); + } + + #[test] + fn split_only_newlines() { + let chunks = split_message("\n\n\n", 400); + assert_eq!(chunks, vec![""]); + } + + // ── Allowlist ─────────────────────────────────────────── + + #[test] + fn wildcard_allows_anyone() { + let ch = make_channel(); + // Default make_channel has wildcard + assert!(ch.is_user_allowed("anyone")); + assert!(ch.is_user_allowed("stranger")); + } + + #[test] + fn specific_user_allowed() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec!["alice".into(), "bob".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); + assert!(ch.is_user_allowed("alice")); + assert!(ch.is_user_allowed("bob")); + assert!(!ch.is_user_allowed("eve")); + } + + #[test] + fn allowlist_case_insensitive() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec!["Alice".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); + assert!(ch.is_user_allowed("alice")); + assert!(ch.is_user_allowed("ALICE")); + assert!(ch.is_user_allowed("Alice")); + } + + #[test] + fn empty_allowlist_denies_all() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); + assert!(!ch.is_user_allowed("anyone")); + } + + // ── Constructor ───────────────────────────────────────── + + #[test] + fn new_defaults_username_to_nickname() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "mybot".into(), + username: None, + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); + assert_eq!(ch.username, "mybot"); + } + + #[test] + fn new_uses_explicit_username() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "mybot".into(), + username: Some("customuser".into()), + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); + assert_eq!(ch.username, "customuser"); + assert_eq!(ch.nickname, "mybot"); + } + + #[test] + fn name_returns_irc() { + let ch = make_channel(); + assert_eq!(ch.name(), "irc"); + } + + #[test] + fn new_stores_all_fields() { + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.example.com".into(), + port: 6697, + nickname: "zcbot".into(), + username: Some("zeroclaw".into()), + channels: vec!["#test".into()], + allowed_users: vec!["alice".into()], + server_password: Some("serverpass".into()), + nickserv_password: Some("nspass".into()), + sasl_password: Some("saslpass".into()), + verify_tls: false, + }); + assert_eq!(ch.server, "irc.example.com"); + assert_eq!(ch.port, 6697); + assert_eq!(ch.nickname, "zcbot"); + assert_eq!(ch.username, "zeroclaw"); + assert_eq!(ch.channels, vec!["#test"]); + assert_eq!(ch.allowed_users, vec!["alice"]); + assert_eq!(ch.server_password.as_deref(), Some("serverpass")); + assert_eq!(ch.nickserv_password.as_deref(), Some("nspass")); + assert_eq!(ch.sasl_password.as_deref(), Some("saslpass")); + assert!(!ch.verify_tls); + } + + // ── Config serde ──────────────────────────────────────── + + #[test] + fn irc_config_serde_roundtrip() { + use crate::config::schema::IrcConfig; + + let config = IrcConfig { + server: "irc.example.com".into(), + port: 6697, + nickname: "zcbot".into(), + username: Some("zeroclaw".into()), + channels: vec!["#test".into(), "#dev".into()], + allowed_users: vec!["alice".into()], + server_password: None, + nickserv_password: Some("secret".into()), + sasl_password: None, + verify_tls: Some(true), + }; + + let toml_str = toml::to_string(&config).unwrap(); + let parsed: IrcConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.server, "irc.example.com"); + assert_eq!(parsed.port, 6697); + assert_eq!(parsed.nickname, "zcbot"); + assert_eq!(parsed.username.as_deref(), Some("zeroclaw")); + assert_eq!(parsed.channels, vec!["#test", "#dev"]); + assert_eq!(parsed.allowed_users, vec!["alice"]); + assert!(parsed.server_password.is_none()); + assert_eq!(parsed.nickserv_password.as_deref(), Some("secret")); + assert!(parsed.sasl_password.is_none()); + assert_eq!(parsed.verify_tls, Some(true)); + } + + #[test] + fn irc_config_minimal_toml() { + use crate::config::schema::IrcConfig; + + let toml_str = r#" +server = "irc.example.com" +nickname = "bot" +"#; + let parsed: IrcConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(parsed.server, "irc.example.com"); + assert_eq!(parsed.port, 6697); // default + assert_eq!(parsed.nickname, "bot"); + assert!(parsed.username.is_none()); + assert!(parsed.channels.is_empty()); + assert!(parsed.allowed_users.is_empty()); + assert!(parsed.server_password.is_none()); + assert!(parsed.nickserv_password.is_none()); + assert!(parsed.sasl_password.is_none()); + assert!(parsed.verify_tls.is_none()); + } + + #[test] + fn irc_config_default_port() { + use crate::config::schema::IrcConfig; + + let json = r#"{"server":"irc.test","nickname":"bot"}"#; + let parsed: IrcConfig = serde_json::from_str(json).unwrap(); + assert_eq!(parsed.port, 6697); + } + + // ── Helpers ───────────────────────────────────────────── + + fn make_channel() -> IrcChannel { + IrcChannel::new(IrcChannelConfig { + server: "irc.example.com".into(), + port: 6697, + nickname: "zcbot".into(), + username: None, + channels: vec!["#zeroclaw".into()], + allowed_users: vec!["*".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }) + } +} diff --git a/src/channels/lark.rs b/src/channels/lark.rs new file mode 100644 index 0000000..c8d6cdb --- /dev/null +++ b/src/channels/lark.rs @@ -0,0 +1,1237 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use prost::Message as ProstMessage; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; +use tokio_tungstenite::tungstenite::Message as WsMsg; +use uuid::Uuid; + +const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis"; +const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn"; +const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis"; +const LARK_WS_BASE_URL: &str = "https://open.larksuite.com"; + +// ───────────────────────────────────────────────────────────────────────────── +// Feishu WebSocket long-connection: pbbp2.proto frame codec +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Clone, PartialEq, prost::Message)] +struct PbHeader { + #[prost(string, tag = "1")] + pub key: String, + #[prost(string, tag = "2")] + pub value: String, +} + +/// Feishu WS frame (pbbp2.proto). +/// method=0 → CONTROL (ping/pong) method=1 → DATA (events) +#[derive(Clone, PartialEq, prost::Message)] +struct PbFrame { + #[prost(uint64, tag = "1")] + pub seq_id: u64, + #[prost(uint64, tag = "2")] + pub log_id: u64, + #[prost(int32, tag = "3")] + pub service: i32, + #[prost(int32, tag = "4")] + pub method: i32, + #[prost(message, repeated, tag = "5")] + pub headers: Vec, + #[prost(bytes = "vec", optional, tag = "8")] + pub payload: Option>, +} + +impl PbFrame { + fn header_value<'a>(&'a self, key: &str) -> &'a str { + self.headers + .iter() + .find(|h| h.key == key) + .map(|h| h.value.as_str()) + .unwrap_or("") + } +} + +/// Server-sent client config (parsed from pong payload) +#[derive(Debug, serde::Deserialize, Default, Clone)] +struct WsClientConfig { + #[serde(rename = "PingInterval")] + ping_interval: Option, +} + +/// POST /callback/ws/endpoint response +#[derive(Debug, serde::Deserialize)] +struct WsEndpointResp { + code: i32, + #[serde(default)] + msg: Option, + #[serde(default)] + data: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct WsEndpoint { + #[serde(rename = "URL")] + url: String, + #[serde(rename = "ClientConfig")] + client_config: Option, +} + +/// LarkEvent envelope (method=1 / type=event payload) +#[derive(Debug, serde::Deserialize)] +struct LarkEvent { + header: LarkEventHeader, + event: serde_json::Value, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkEventHeader { + event_type: String, + #[allow(dead_code)] + event_id: String, +} + +#[derive(Debug, serde::Deserialize)] +struct MsgReceivePayload { + sender: LarkSender, + message: LarkMessage, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkSender { + sender_id: LarkSenderId, + #[serde(default)] + sender_type: String, +} + +#[derive(Debug, serde::Deserialize, Default)] +struct LarkSenderId { + open_id: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkMessage { + message_id: String, + chat_id: String, + chat_type: String, + message_type: String, + #[serde(default)] + content: String, + #[serde(default)] + mentions: Vec, +} + +/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s). +/// If no binary frame (pong or event) is received within this window, reconnect. +const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300); + +/// Lark/Feishu channel. +/// +/// Supports two receive modes (configured via `receive_mode` in config): +/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed. +/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint. +pub struct LarkChannel { + app_id: String, + app_secret: String, + verification_token: String, + port: Option, + allowed_users: Vec, + /// When true, use Feishu (CN) endpoints; when false, use Lark (international). + use_feishu: bool, + /// How to receive events: WebSocket long-connection or HTTP webhook. + receive_mode: crate::config::schema::LarkReceiveMode, + client: reqwest::Client, + /// Cached tenant access token + tenant_token: Arc>>, + /// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch + ws_seen_ids: Arc>>, +} + +impl LarkChannel { + pub fn new( + app_id: String, + app_secret: String, + verification_token: String, + port: Option, + allowed_users: Vec, + ) -> Self { + Self { + app_id, + app_secret, + verification_token, + port, + allowed_users, + use_feishu: true, + receive_mode: crate::config::schema::LarkReceiveMode::default(), + client: reqwest::Client::new(), + tenant_token: Arc::new(RwLock::new(None)), + ws_seen_ids: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`). + pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self { + let mut ch = Self::new( + config.app_id.clone(), + config.app_secret.clone(), + config.verification_token.clone().unwrap_or_default(), + config.port, + config.allowed_users.clone(), + ); + ch.use_feishu = config.use_feishu; + ch.receive_mode = config.receive_mode.clone(); + ch + } + + fn api_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_BASE_URL + } else { + LARK_BASE_URL + } + } + + fn ws_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_WS_BASE_URL + } else { + LARK_WS_BASE_URL + } + } + + fn tenant_access_token_url(&self) -> String { + format!("{}/auth/v3/tenant_access_token/internal", self.api_base()) + } + + fn send_message_url(&self) -> String { + format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base()) + } + + /// POST /callback/ws/endpoint → (wss_url, client_config) + async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> { + let resp = self + .client + .post(format!("{}/callback/ws/endpoint", self.ws_base())) + .header("locale", if self.use_feishu { "zh" } else { "en" }) + .json(&serde_json::json!({ + "AppID": self.app_id, + "AppSecret": self.app_secret, + })) + .send() + .await? + .json::() + .await?; + if resp.code != 0 { + anyhow::bail!( + "Lark WS endpoint failed: code={} msg={}", + resp.code, + resp.msg.as_deref().unwrap_or("(none)") + ); + } + let ep = resp + .data + .ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?; + Ok((ep.url, ep.client_config.unwrap_or_default())) + } + + /// WS long-connection event loop. Returns Ok(()) when the connection closes + /// (the caller reconnects). + #[allow(clippy::too_many_lines)] + async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + let (wss_url, client_config) = self.get_ws_endpoint().await?; + let service_id = wss_url + .split('?') + .nth(1) + .and_then(|qs| { + qs.split('&') + .find(|kv| kv.starts_with("service_id=")) + .and_then(|kv| kv.split('=').nth(1)) + .and_then(|v| v.parse::().ok()) + }) + .unwrap_or(0); + tracing::info!("Lark: connecting to {wss_url}"); + + let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?; + let (mut write, mut read) = ws_stream.split(); + tracing::info!("Lark: WS connected (service_id={service_id})"); + + let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10); + let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + let mut timeout_check = tokio::time::interval(Duration::from_secs(10)); + hb_interval.tick().await; // consume immediate tick + + let mut seq: u64 = 0; + let mut last_recv = Instant::now(); + + // Send initial ping immediately (like the official SDK) so the server + // starts responding with pongs and we can calibrate the ping_interval. + seq = seq.wrapping_add(1); + let initial_ping = PbFrame { + seq_id: seq, + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "ping".into(), + }], + payload: None, + }; + if write + .send(WsMsg::Binary(initial_ping.encode_to_vec())) + .await + .is_err() + { + anyhow::bail!("Lark: initial ping failed"); + } + // message_id → (fragment_slots, created_at) for multi-part reassembly + type FragEntry = (Vec>>, Instant); + let mut frag_cache: HashMap = HashMap::new(); + + loop { + tokio::select! { + biased; + + _ = hb_interval.tick() => { + seq = seq.wrapping_add(1); + let ping = PbFrame { + seq_id: seq, log_id: 0, service: service_id, method: 0, + headers: vec![PbHeader { key: "type".into(), value: "ping".into() }], + payload: None, + }; + if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() { + tracing::warn!("Lark: ping failed, reconnecting"); + break; + } + // GC stale fragments > 5 min + let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now()); + frag_cache.retain(|_, (_, ts)| *ts > cutoff); + } + + _ = timeout_check.tick() => { + if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT { + tracing::warn!("Lark: heartbeat timeout, reconnecting"); + break; + } + } + + msg = read.next() => { + let raw = match msg { + Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b } + Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; } + Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; } + Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; } + _ => continue, + }; + + let frame = match PbFrame::decode(&raw[..]) { + Ok(f) => f, + Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; } + }; + + // CONTROL frame + if frame.method == 0 { + if frame.header_value("type") == "pong" { + if let Some(p) = &frame.payload { + if let Ok(cfg) = serde_json::from_slice::(p) { + if let Some(secs) = cfg.ping_interval { + let secs = secs.max(10); + if secs != ping_secs { + ping_secs = secs; + hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + tracing::info!("Lark: ping_interval → {ping_secs}s"); + } + } + } + } + } + continue; + } + + // DATA frame + let msg_type = frame.header_value("type").to_string(); + let msg_id = frame.header_value("message_id").to_string(); + let sum = frame.header_value("sum").parse::().unwrap_or(1); + let seq_num = frame.header_value("seq").parse::().unwrap_or(0); + + // ACK immediately (Feishu requires within 3 s) + { + let mut ack = frame.clone(); + ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); + ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() }); + let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await; + } + + // Fragment reassembly + let sum = if sum == 0 { 1 } else { sum }; + let payload: Vec = if sum == 1 || msg_id.is_empty() || seq_num >= sum { + frame.payload.clone().unwrap_or_default() + } else { + let entry = frag_cache.entry(msg_id.clone()) + .or_insert_with(|| (vec![None; sum], Instant::now())); + if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); } + entry.0[seq_num] = frame.payload.clone(); + if entry.0.iter().all(|s| s.is_some()) { + let full: Vec = entry.0.iter() + .flat_map(|s| s.as_deref().unwrap_or(&[])) + .copied().collect(); + frag_cache.remove(&msg_id); + full + } else { continue; } + }; + + if msg_type != "event" { continue; } + + let event: LarkEvent = match serde_json::from_slice(&payload) { + Ok(e) => e, + Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; } + }; + if event.header.event_type != "im.message.receive_v1" { continue; } + + let recv: MsgReceivePayload = match serde_json::from_value(event.event) { + Ok(r) => r, + Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; } + }; + + if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; } + + let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or(""); + if !self.is_user_allowed(sender_open_id) { + tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)"); + continue; + } + + let lark_msg = &recv.message; + + // Dedup + { + let now = Instant::now(); + let mut seen = self.ws_seen_ids.write().await; + // GC + seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60)); + if seen.contains_key(&lark_msg.message_id) { + tracing::debug!("Lark WS: dup {}", lark_msg.message_id); + continue; + } + seen.insert(lark_msg.message_id.clone(), now); + } + + // Decode content by type (mirrors clawdbot-feishu parsing) + let text = match lark_msg.message_type.as_str() { + "text" => { + let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) { + Ok(v) => v, + Err(_) => continue, + }; + match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) { + Some(t) => t.to_string(), + None => continue, + } + } + "post" => match parse_post_content(&lark_msg.content) { + Some(t) => t, + None => continue, + }, + _ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; } + }; + + // Strip @_user_N placeholders + let text = strip_at_placeholders(&text); + let text = text.trim().to_string(); + if text.is_empty() { continue; } + + // Group-chat: only respond when explicitly @-mentioned + if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) { + continue; + } + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: lark_msg.chat_id.clone(), + reply_target: lark_msg.chat_id.clone(), + content: text, + channel: "lark".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + tracing::debug!("Lark WS: message in {}", lark_msg.chat_id); + if tx.send(channel_msg).await.is_err() { break; } + } + } + } + Ok(()) + } + + /// Check if a user open_id is allowed + fn is_user_allowed(&self, open_id: &str) -> bool { + self.allowed_users.iter().any(|u| u == "*" || u == open_id) + } + + /// Get or refresh tenant access token + async fn get_tenant_access_token(&self) -> anyhow::Result { + // Check cache first + { + let cached = self.tenant_token.read().await; + if let Some(ref token) = *cached { + return Ok(token.clone()); + } + } + + let url = self.tenant_access_token_url(); + let body = serde_json::json!({ + "app_id": self.app_id, + "app_secret": self.app_secret, + }); + + let resp = self.client.post(&url).json(&body).send().await?; + let data: serde_json::Value = resp.json().await?; + + let code = data.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); + if code != 0 { + let msg = data + .get("msg") + .and_then(|m| m.as_str()) + .unwrap_or("unknown error"); + anyhow::bail!("Lark tenant_access_token failed: {msg}"); + } + + let token = data + .get("tenant_access_token") + .and_then(|t| t.as_str()) + .ok_or_else(|| anyhow::anyhow!("missing tenant_access_token in response"))? + .to_string(); + + // Cache it + { + let mut cached = self.tenant_token.write().await; + *cached = Some(token.clone()); + } + + Ok(token) + } + + /// Invalidate cached token (called on 401) + async fn invalidate_token(&self) { + let mut cached = self.tenant_token.write().await; + *cached = None; + } + + /// Parse an event callback payload and extract text messages + pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec { + let mut messages = Vec::new(); + + // Lark event v2 structure: + // { "header": { "event_type": "im.message.receive_v1" }, "event": { "message": { ... }, "sender": { ... } } } + let event_type = payload + .pointer("/header/event_type") + .and_then(|e| e.as_str()) + .unwrap_or(""); + + if event_type != "im.message.receive_v1" { + return messages; + } + + let event = match payload.get("event") { + Some(e) => e, + None => return messages, + }; + + // Extract sender open_id + let open_id = event + .pointer("/sender/sender_id/open_id") + .and_then(|s| s.as_str()) + .unwrap_or(""); + + if open_id.is_empty() { + return messages; + } + + // Check allowlist + if !self.is_user_allowed(open_id) { + tracing::warn!("Lark: ignoring message from unauthorized user: {open_id}"); + return messages; + } + + // Extract message content (text and post supported) + let msg_type = event + .pointer("/message/message_type") + .and_then(|t| t.as_str()) + .unwrap_or(""); + + let content_str = event + .pointer("/message/content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + + let text: String = match msg_type { + "text" => { + let extracted = serde_json::from_str::(content_str) + .ok() + .and_then(|v| { + v.get("text") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + .map(String::from) + }); + match extracted { + Some(t) => t, + None => return messages, + } + } + "post" => match parse_post_content(content_str) { + Some(t) => t, + None => return messages, + }, + _ => { + tracing::debug!("Lark: skipping unsupported message type: {msg_type}"); + return messages; + } + }; + + let timestamp = event + .pointer("/message/create_time") + .and_then(|t| t.as_str()) + .and_then(|t| t.parse::().ok()) + // Lark timestamps are in milliseconds + .map(|ms| ms / 1000) + .unwrap_or_else(|| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + }); + + let chat_id = event + .pointer("/message/chat_id") + .and_then(|c| c.as_str()) + .unwrap_or(open_id); + + messages.push(ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: chat_id.to_string(), + reply_target: chat_id.to_string(), + content: text, + channel: "lark".to_string(), + timestamp, + }); + + messages + } +} + +#[async_trait] +impl Channel for LarkChannel { + fn name(&self) -> &str { + "lark" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let token = self.get_tenant_access_token().await?; + let url = self.send_message_url(); + + let content = serde_json::json!({ "text": message.content }).to_string(); + let body = serde_json::json!({ + "receive_id": message.recipient, + "msg_type": "text", + "content": content, + }); + + let resp = self + .client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .header("Content-Type", "application/json; charset=utf-8") + .json(&body) + .send() + .await?; + + if resp.status().as_u16() == 401 { + // Token expired, invalidate and retry once + self.invalidate_token().await; + let new_token = self.get_tenant_access_token().await?; + let retry_resp = self + .client + .post(&url) + .header("Authorization", format!("Bearer {new_token}")) + .header("Content-Type", "application/json; charset=utf-8") + .json(&body) + .send() + .await?; + + if !retry_resp.status().is_success() { + let err = retry_resp.text().await.unwrap_or_default(); + anyhow::bail!("Lark send failed after token refresh: {err}"); + } + return Ok(()); + } + + if !resp.status().is_success() { + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("Lark send failed: {err}"); + } + + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + use crate::config::schema::LarkReceiveMode; + match self.receive_mode { + LarkReceiveMode::Websocket => self.listen_ws(tx).await, + LarkReceiveMode::Webhook => self.listen_http(tx).await, + } + } + + async fn health_check(&self) -> bool { + self.get_tenant_access_token().await.is_ok() + } +} + +impl LarkChannel { + /// HTTP callback server (legacy — requires a public endpoint). + /// Use `listen()` (WS long-connection) for new deployments. + pub async fn listen_http( + &self, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + use axum::{extract::State, routing::post, Json, Router}; + + #[derive(Clone)] + struct AppState { + verification_token: String, + channel: Arc, + tx: tokio::sync::mpsc::Sender, + } + + async fn handle_event( + State(state): State, + Json(payload): Json, + ) -> axum::response::Response { + use axum::http::StatusCode; + use axum::response::IntoResponse; + + // URL verification challenge + if let Some(challenge) = payload.get("challenge").and_then(|c| c.as_str()) { + // Verify token if present + let token_ok = payload + .get("token") + .and_then(|t| t.as_str()) + .map_or(true, |t| t == state.verification_token); + + if !token_ok { + return (StatusCode::FORBIDDEN, "invalid token").into_response(); + } + + let resp = serde_json::json!({ "challenge": challenge }); + return (StatusCode::OK, Json(resp)).into_response(); + } + + // Parse event messages + let messages = state.channel.parse_event_payload(&payload); + for msg in messages { + if state.tx.send(msg).await.is_err() { + tracing::warn!("Lark: message channel closed"); + break; + } + } + + (StatusCode::OK, "ok").into_response() + } + + let port = self.port.ok_or_else(|| { + anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]") + })?; + + let state = AppState { + verification_token: self.verification_token.clone(), + channel: Arc::new(LarkChannel::new( + self.app_id.clone(), + self.app_secret.clone(), + self.verification_token.clone(), + None, + self.allowed_users.clone(), + )), + tx, + }; + + let app = Router::new() + .route("/lark", post(handle_event)) + .with_state(state); + + let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port)); + tracing::info!("Lark event callback server listening on {addr}"); + + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + + Ok(()) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// WS helper functions +// ───────────────────────────────────────────────────────────────────────────── + +/// Flatten a Feishu `post` rich-text message to plain text. +/// +/// Returns `None` when the content cannot be parsed or yields no usable text, +/// so callers can simply `continue` rather than forwarding a meaningless +/// placeholder string to the agent. +fn parse_post_content(content: &str) -> Option { + let parsed = serde_json::from_str::(content).ok()?; + let locale = parsed + .get("zh_cn") + .or_else(|| parsed.get("en_us")) + .or_else(|| { + parsed + .as_object() + .and_then(|m| m.values().find(|v| v.is_object())) + })?; + + let mut text = String::new(); + + if let Some(title) = locale + .get("title") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + { + text.push_str(title); + text.push_str("\n\n"); + } + + if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) { + for para in paragraphs { + if let Some(elements) = para.as_array() { + for el in elements { + match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") { + "text" => { + if let Some(t) = el.get("text").and_then(|t| t.as_str()) { + text.push_str(t); + } + } + "a" => { + text.push_str( + el.get("text") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + .or_else(|| el.get("href").and_then(|h| h.as_str())) + .unwrap_or(""), + ); + } + "at" => { + let n = el + .get("user_name") + .and_then(|n| n.as_str()) + .or_else(|| el.get("user_id").and_then(|i| i.as_str())) + .unwrap_or("user"); + text.push('@'); + text.push_str(n); + } + _ => {} + } + } + text.push('\n'); + } + } + } + + let result = text.trim().to_string(); + if result.is_empty() { + None + } else { + Some(result) + } +} + +/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats. +fn strip_at_placeholders(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut chars = text.char_indices().peekable(); + while let Some((_, ch)) = chars.next() { + if ch == '@' { + let rest: String = chars.clone().map(|(_, c)| c).collect(); + if let Some(after) = rest.strip_prefix("_user_") { + let skip = + "_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count(); + for _ in 0..=skip { + chars.next(); + } + if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) { + chars.next(); + } + continue; + } + } + result.push(ch); + } + result +} + +/// In group chats, only respond when the bot is explicitly @-mentioned. +fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool { + !mentions.is_empty() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_channel() -> LarkChannel { + LarkChannel::new( + "cli_test_app_id".into(), + "test_app_secret".into(), + "test_verification_token".into(), + None, + vec!["ou_testuser123".into()], + ) + } + + #[test] + fn lark_channel_name() { + let ch = make_channel(); + assert_eq!(ch.name(), "lark"); + } + + #[test] + fn lark_user_allowed_exact() { + let ch = make_channel(); + assert!(ch.is_user_allowed("ou_testuser123")); + assert!(!ch.is_user_allowed("ou_other")); + } + + #[test] + fn lark_user_allowed_wildcard() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + assert!(ch.is_user_allowed("ou_anyone")); + } + + #[test] + fn lark_user_denied_empty() { + let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]); + assert!(!ch.is_user_allowed("ou_anyone")); + } + + #[test] + fn lark_parse_challenge() { + let ch = make_channel(); + let payload = serde_json::json!({ + "challenge": "abc123", + "token": "test_verification_token", + "type": "url_verification" + }); + // Challenge payloads should not produce messages + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_valid_text_message() { + let ch = make_channel(); + let payload = serde_json::json!({ + "header": { + "event_type": "im.message.receive_v1" + }, + "event": { + "sender": { + "sender_id": { + "open_id": "ou_testuser123" + } + }, + "message": { + "message_type": "text", + "content": "{\"text\":\"Hello ZeroClaw!\"}", + "chat_id": "oc_chat123", + "create_time": "1699999999000" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "Hello ZeroClaw!"); + assert_eq!(msgs[0].sender, "oc_chat123"); + assert_eq!(msgs[0].channel, "lark"); + assert_eq!(msgs[0].timestamp, 1_699_999_999); + } + + #[test] + fn lark_parse_unauthorized_user() { + let ch = make_channel(); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_unauthorized" } }, + "message": { + "message_type": "text", + "content": "{\"text\":\"spam\"}", + "chat_id": "oc_chat", + "create_time": "1000" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_non_text_message_skipped() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_user" } }, + "message": { + "message_type": "image", + "content": "{}", + "chat_id": "oc_chat" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_empty_text_skipped() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_user" } }, + "message": { + "message_type": "text", + "content": "{\"text\":\"\"}", + "chat_id": "oc_chat" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_wrong_event_type() { + let ch = make_channel(); + let payload = serde_json::json!({ + "header": { "event_type": "im.chat.disbanded_v1" }, + "event": {} + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_missing_sender() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "message": { + "message_type": "text", + "content": "{\"text\":\"hello\"}", + "chat_id": "oc_chat" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_unicode_message() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_user" } }, + "message": { + "message_type": "text", + "content": "{\"text\":\"Hello world 🌍\"}", + "chat_id": "oc_chat", + "create_time": "1000" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "Hello world 🌍"); + } + + #[test] + fn lark_parse_missing_event() { + let ch = make_channel(); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_parse_invalid_content_json() { + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_user" } }, + "message": { + "message_type": "text", + "content": "not valid json", + "chat_id": "oc_chat" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn lark_config_serde() { + use crate::config::schema::{LarkConfig, LarkReceiveMode}; + let lc = LarkConfig { + app_id: "cli_app123".into(), + app_secret: "secret456".into(), + encrypt_key: None, + verification_token: Some("vtoken789".into()), + allowed_users: vec!["ou_user1".into(), "ou_user2".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::default(), + port: None, + }; + let json = serde_json::to_string(&lc).unwrap(); + let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.app_id, "cli_app123"); + assert_eq!(parsed.app_secret, "secret456"); + assert_eq!(parsed.verification_token.as_deref(), Some("vtoken789")); + assert_eq!(parsed.allowed_users.len(), 2); + } + + #[test] + fn lark_config_toml_roundtrip() { + use crate::config::schema::{LarkConfig, LarkReceiveMode}; + let lc = LarkConfig { + app_id: "app".into(), + app_secret: "secret".into(), + encrypt_key: None, + verification_token: Some("tok".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + let toml_str = toml::to_string(&lc).unwrap(); + let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.app_id, "app"); + assert_eq!(parsed.verification_token.as_deref(), Some("tok")); + assert_eq!(parsed.allowed_users, vec!["*"]); + } + + #[test] + fn lark_config_defaults_optional_fields() { + use crate::config::schema::{LarkConfig, LarkReceiveMode}; + let json = r#"{"app_id":"a","app_secret":"s"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.verification_token.is_none()); + assert!(parsed.allowed_users.is_empty()); + assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket); + assert!(parsed.port.is_none()); + } + + #[test] + fn lark_from_config_preserves_mode_and_region() { + use crate::config::schema::{LarkConfig, LarkReceiveMode}; + + let cfg = LarkConfig { + app_id: "cli_app123".into(), + app_secret: "secret456".into(), + encrypt_key: None, + verification_token: Some("vtoken789".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + + let ch = LarkChannel::from_config(&cfg); + + assert_eq!(ch.api_base(), LARK_BASE_URL); + assert_eq!(ch.ws_base(), LARK_WS_BASE_URL); + assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook); + assert_eq!(ch.port, Some(9898)); + } + + #[test] + fn lark_parse_fallback_sender_to_open_id() { + // When chat_id is missing, sender should fall back to open_id + let ch = LarkChannel::new( + "id".into(), + "secret".into(), + "token".into(), + None, + vec!["*".into()], + ); + let payload = serde_json::json!({ + "header": { "event_type": "im.message.receive_v1" }, + "event": { + "sender": { "sender_id": { "open_id": "ou_user" } }, + "message": { + "message_type": "text", + "content": "{\"text\":\"hello\"}", + "create_time": "1000" + } + } + }); + + let msgs = ch.parse_event_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].sender, "ou_user"); + } +} diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index 9f8924c..9b327d2 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -1,4 +1,4 @@ -use crate::channels::traits::{Channel, ChannelMessage}; +use crate::channels::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; use reqwest::Client; use serde::Deserialize; @@ -117,7 +117,7 @@ impl Channel for MatrixChannel { "matrix" } - async fn send(&self, message: &str, _target: &str) -> anyhow::Result<()> { + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { let txn_id = format!("zc_{}", chrono::Utc::now().timestamp_millis()); let url = format!( "{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}", @@ -126,7 +126,7 @@ impl Channel for MatrixChannel { let body = serde_json::json!({ "msgtype": "m.text", - "body": message + "body": message.content }); let resp = self @@ -230,6 +230,7 @@ impl Channel for MatrixChannel { let msg = ChannelMessage { id: format!("mx_{}", chrono::Utc::now().timestamp_millis()), sender: event.sender.clone(), + reply_target: event.sender.clone(), content: body.clone(), channel: "matrix".to_string(), timestamp: std::time::SystemTime::now() diff --git a/src/channels/mattermost.rs b/src/channels/mattermost.rs new file mode 100644 index 0000000..a10cd72 --- /dev/null +++ b/src/channels/mattermost.rs @@ -0,0 +1,313 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use anyhow::{bail, Result}; +use async_trait::async_trait; + +/// Mattermost channel — polls channel posts via REST API v4. +/// Mattermost is API-compatible with many Slack patterns but uses a dedicated v4 structure. +pub struct MattermostChannel { + base_url: String, // e.g., https://mm.example.com + bot_token: String, + channel_id: Option, + allowed_users: Vec, + client: reqwest::Client, +} + +impl MattermostChannel { + pub fn new( + base_url: String, + bot_token: String, + channel_id: Option, + allowed_users: Vec, + ) -> Self { + // Ensure base_url doesn't have a trailing slash for consistent path joining + let base_url = base_url.trim_end_matches('/').to_string(); + Self { + base_url, + bot_token, + channel_id, + allowed_users, + client: reqwest::Client::new(), + } + } + + /// Check if a user ID is in the allowlist. + /// Empty list means deny everyone. "*" means allow everyone. + fn is_user_allowed(&self, user_id: &str) -> bool { + self.allowed_users.iter().any(|u| u == "*" || u == user_id) + } + + /// Get the bot's own user ID so we can ignore our own messages. + async fn get_bot_user_id(&self) -> Option { + let resp: serde_json::Value = self + .client + .get(format!("{}/api/v4/users/me", self.base_url)) + .bearer_auth(&self.bot_token) + .send() + .await + .ok()? + .json() + .await + .ok()?; + + resp.get("id").and_then(|u| u.as_str()).map(String::from) + } +} + +#[async_trait] +impl Channel for MattermostChannel { + fn name(&self) -> &str { + "mattermost" + } + + async fn send(&self, message: &SendMessage) -> Result<()> { + // Mattermost supports threading via 'root_id'. + // We pack 'channel_id:root_id' into recipient if it's a thread. + let (channel_id, root_id) = if let Some((c, r)) = message.recipient.split_once(':') { + (c, Some(r)) + } else { + (message.recipient.as_str(), None) + }; + + let mut body_map = serde_json::json!({ + "channel_id": channel_id, + "message": message.content + }); + + if let Some(root) = root_id { + body_map.as_object_mut().unwrap().insert( + "root_id".to_string(), + serde_json::Value::String(root.to_string()), + ); + } + + let resp = self + .client + .post(format!("{}/api/v4/posts", self.base_url)) + .bearer_auth(&self.bot_token) + .json(&body_map) + .send() + .await?; + + let status = resp.status(); + if !status.is_success() { + let body = resp + .text() + .await + .unwrap_or_else(|e| format!("")); + bail!("Mattermost post failed ({status}): {body}"); + } + + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> Result<()> { + let channel_id = self + .channel_id + .clone() + .ok_or_else(|| anyhow::anyhow!("Mattermost channel_id required for listening"))?; + + let bot_user_id = self.get_bot_user_id().await.unwrap_or_default(); + #[allow(clippy::cast_possible_truncation)] + let mut last_create_at = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis()) as i64; + + tracing::info!("Mattermost channel listening on {}...", channel_id); + + loop { + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + + let resp = match self + .client + .get(format!( + "{}/api/v4/channels/{}/posts", + self.base_url, channel_id + )) + .bearer_auth(&self.bot_token) + .query(&[("since", last_create_at.to_string())]) + .send() + .await + { + Ok(r) => r, + Err(e) => { + tracing::warn!("Mattermost poll error: {e}"); + continue; + } + }; + + let data: serde_json::Value = match resp.json().await { + Ok(d) => d, + Err(e) => { + tracing::warn!("Mattermost parse error: {e}"); + continue; + } + }; + + if let Some(posts) = data.get("posts").and_then(|p| p.as_object()) { + // Process in chronological order + let mut post_list: Vec<_> = posts.values().collect(); + post_list.sort_by_key(|p| p.get("create_at").and_then(|c| c.as_i64()).unwrap_or(0)); + + for post in post_list { + let msg = + self.parse_mattermost_post(post, &bot_user_id, last_create_at, &channel_id); + let create_at = post + .get("create_at") + .and_then(|c| c.as_i64()) + .unwrap_or(last_create_at); + last_create_at = last_create_at.max(create_at); + + if let Some(channel_msg) = msg { + if tx.send(channel_msg).await.is_err() { + return Ok(()); + } + } + } + } + } + } + + async fn health_check(&self) -> bool { + self.client + .get(format!("{}/api/v4/users/me", self.base_url)) + .bearer_auth(&self.bot_token) + .send() + .await + .map(|r| r.status().is_success()) + .unwrap_or(false) + } +} + +impl MattermostChannel { + fn parse_mattermost_post( + &self, + post: &serde_json::Value, + bot_user_id: &str, + last_create_at: i64, + channel_id: &str, + ) -> Option { + let id = post.get("id").and_then(|i| i.as_str()).unwrap_or(""); + let user_id = post.get("user_id").and_then(|u| u.as_str()).unwrap_or(""); + let text = post.get("message").and_then(|m| m.as_str()).unwrap_or(""); + let create_at = post.get("create_at").and_then(|c| c.as_i64()).unwrap_or(0); + let root_id = post.get("root_id").and_then(|r| r.as_str()).unwrap_or(""); + + if user_id == bot_user_id || create_at <= last_create_at || text.is_empty() { + return None; + } + + if !self.is_user_allowed(user_id) { + tracing::warn!("Mattermost: ignoring message from unauthorized user: {user_id}"); + return None; + } + + // If it's a thread, include root_id in reply_to so we reply in the same thread + let reply_target = if root_id.is_empty() { + // Or if it's a top-level message that WE want to start a thread on, + // the next reply will use THIS post's ID as root_id. + // But for now, we follow Mattermost's 'reply' convention where + // replying to a post uses its ID as root_id. + format!("{}:{}", channel_id, id) + } else { + format!("{}:{}", channel_id, root_id) + }; + + Some(ChannelMessage { + id: format!("mattermost_{id}"), + sender: user_id.to_string(), + reply_target, + content: text.to_string(), + channel: "mattermost".to_string(), + #[allow(clippy::cast_sign_loss)] + timestamp: (create_at / 1000) as u64, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn mattermost_url_trimming() { + let ch = MattermostChannel::new( + "https://mm.example.com/".into(), + "token".into(), + None, + vec![], + ); + assert_eq!(ch.base_url, "https://mm.example.com"); + } + + #[test] + fn mattermost_allowlist_wildcard() { + let ch = MattermostChannel::new("url".into(), "token".into(), None, vec!["*".into()]); + assert!(ch.is_user_allowed("any-id")); + } + + #[test] + fn mattermost_parse_post_basic() { + let ch = MattermostChannel::new("url".into(), "token".into(), None, vec!["*".into()]); + let post = json!({ + "id": "post123", + "user_id": "user456", + "message": "hello world", + "create_at": 1_600_000_000_000_i64, + "root_id": "" + }); + + let msg = ch + .parse_mattermost_post(&post, "bot123", 1_500_000_000_000_i64, "chan789") + .unwrap(); + assert_eq!(msg.sender, "user456"); + assert_eq!(msg.content, "hello world"); + assert_eq!(msg.reply_target, "chan789:post123"); // Threads on the post + } + + #[test] + fn mattermost_parse_post_thread() { + let ch = MattermostChannel::new("url".into(), "token".into(), None, vec!["*".into()]); + let post = json!({ + "id": "post123", + "user_id": "user456", + "message": "reply", + "create_at": 1_600_000_000_000_i64, + "root_id": "root789" + }); + + let msg = ch + .parse_mattermost_post(&post, "bot123", 1_500_000_000_000_i64, "chan789") + .unwrap(); + assert_eq!(msg.reply_target, "chan789:root789"); // Stays in the thread + } + + #[test] + fn mattermost_parse_post_ignore_self() { + let ch = MattermostChannel::new("url".into(), "token".into(), None, vec!["*".into()]); + let post = json!({ + "id": "post123", + "user_id": "bot123", + "message": "my own message", + "create_at": 1_600_000_000_000_i64 + }); + + let msg = ch.parse_mattermost_post(&post, "bot123", 1_500_000_000_000_i64, "chan789"); + assert!(msg.is_none()); + } + + #[test] + fn mattermost_parse_post_ignore_old() { + let ch = MattermostChannel::new("url".into(), "token".into(), None, vec!["*".into()]); + let post = json!({ + "id": "post123", + "user_id": "user456", + "message": "old message", + "create_at": 1_400_000_000_000_i64 + }); + + let msg = ch.parse_mattermost_post(&post, "bot123", 1_500_000_000_000_i64, "chan789"); + assert!(msg.is_none()); + } +} diff --git a/src/channels/mod.rs b/src/channels/mod.rs index bb98a87..9dc0dbd 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,33 +1,106 @@ pub mod cli; +pub mod dingtalk; pub mod discord; +pub mod email_channel; pub mod imessage; +pub mod irc; +pub mod lark; pub mod matrix; +pub mod mattermost; +pub mod qq; +pub mod signal; pub mod slack; pub mod telegram; pub mod traits; pub mod whatsapp; pub use cli::CliChannel; +pub use dingtalk::DingTalkChannel; pub use discord::DiscordChannel; +pub use email_channel::EmailChannel; pub use imessage::IMessageChannel; +pub use irc::IrcChannel; +pub use lark::LarkChannel; pub use matrix::MatrixChannel; +pub use mattermost::MattermostChannel; +pub use qq::QQChannel; +pub use signal::SignalChannel; pub use slack::SlackChannel; pub use telegram::TelegramChannel; -pub use traits::Channel; +pub use traits::{Channel, SendMessage}; pub use whatsapp::WhatsAppChannel; +use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop}; use crate::config::Config; +use crate::identity; use crate::memory::{self, Memory}; -use crate::providers::{self, Provider}; -use anyhow::Result; +use crate::observability::{self, Observer}; +use crate::providers::{self, ChatMessage, Provider}; +use crate::runtime; +use crate::security::SecurityPolicy; +use crate::tools::{self, Tool}; +use crate::util::truncate_with_ellipsis; +use anyhow::{Context, Result}; +use std::collections::HashMap; +use std::fmt::Write; +use std::path::PathBuf; +use std::process::Command; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; /// Maximum characters per injected workspace file (matches `OpenClaw` default). const BOOTSTRAP_MAX_CHARS: usize = 20_000; const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2; const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60; +/// Timeout for processing a single channel message (LLM + tools). +/// 300s for on-device LLMs (Ollama) which are slower than cloud APIs. +const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300; +const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4; +const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; +const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; + +#[derive(Clone)] +struct ChannelRuntimeContext { + channels_by_name: Arc>>, + provider: Arc, + memory: Arc, + tools_registry: Arc>>, + observer: Arc, + system_prompt: Arc, + model: Arc, + temperature: f64, + auto_save_memory: bool, +} + +fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { + format!("{}_{}_{}", msg.channel, msg.sender, msg.id) +} + +fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { + match channel_name { + "telegram" => Some( + "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:], [DOCUMENT:], [VIDEO:], [AUDIO:], or [VOICE:]. Keep normal user-facing text outside markers and never wrap markers in code fences.", + ), + _ => None, + } +} + +async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { + let mut context = String::new(); + + if let Ok(entries) = mem.recall(user_msg, 5, None).await { + if !entries.is_empty() { + context.push_str("[Memory context]\n"); + for entry in &entries { + let _ = writeln!(context, "- {}: {}", entry.key, entry.content); + } + context.push('\n'); + } + } + + context +} fn spawn_supervised_listener( ch: Arc, @@ -52,6 +125,8 @@ fn spawn_supervised_listener( Ok(()) => { tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name()); crate::health::mark_component_error(&component, "listener exited unexpectedly"); + // Clean exit — reset backoff since the listener ran successfully + backoff = initial_backoff_secs.max(1); } Err(e) => { tracing::error!("Channel {} error: {e}; restarting", ch.name()); @@ -61,14 +136,217 @@ fn spawn_supervised_listener( crate::health::bump_component_restart(&component); tokio::time::sleep(Duration::from_secs(backoff)).await; + // Double backoff AFTER sleeping so first error uses initial_backoff backoff = backoff.saturating_mul(2).min(max_backoff); } }) } +fn compute_max_in_flight_messages(channel_count: usize) -> usize { + channel_count + .saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL) + .clamp( + CHANNEL_MIN_IN_FLIGHT_MESSAGES, + CHANNEL_MAX_IN_FLIGHT_MESSAGES, + ) +} + +fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) { + if let Err(error) = result { + tracing::error!("Channel message worker crashed: {error}"); + } +} + +async fn process_channel_message(ctx: Arc, msg: traits::ChannelMessage) { + println!( + " 💬 [{}] from {}: {}", + msg.channel, + msg.sender, + truncate_with_ellipsis(&msg.content, 80) + ); + + let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await; + + if ctx.auto_save_memory { + let autosave_key = conversation_memory_key(&msg); + let _ = ctx + .memory + .store( + &autosave_key, + &msg.content, + crate::memory::MemoryCategory::Conversation, + None, + ) + .await; + } + + let enriched_message = if memory_context.is_empty() { + msg.content.clone() + } else { + format!("{memory_context}{}", msg.content) + }; + + let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); + + if let Some(channel) = target_channel.as_ref() { + if let Err(e) = channel.start_typing(&msg.reply_target).await { + tracing::debug!("Failed to start typing on {}: {e}", channel.name()); + } + } + + println!(" ⏳ Processing message..."); + let started_at = Instant::now(); + + let mut history = vec![ + ChatMessage::system(ctx.system_prompt.as_str()), + ChatMessage::user(&enriched_message), + ]; + + if let Some(instructions) = channel_delivery_instructions(&msg.channel) { + history.push(ChatMessage::system(instructions)); + } + + let llm_result = tokio::time::timeout( + Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), + run_tool_call_loop( + ctx.provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + ctx.observer.as_ref(), + "channel-runtime", + ctx.model.as_str(), + ctx.temperature, + true, // silent — channels don't write to stdout + None, + msg.channel.as_str(), + ), + ) + .await; + + if let Some(channel) = target_channel.as_ref() { + if let Err(e) = channel.stop_typing(&msg.reply_target).await { + tracing::debug!("Failed to stop typing on {}: {e}", channel.name()); + } + } + + match llm_result { + Ok(Ok(response)) => { + println!( + " 🤖 Reply ({}ms): {}", + started_at.elapsed().as_millis(), + truncate_with_ellipsis(&response, 80) + ); + if let Some(channel) = target_channel.as_ref() { + if let Err(e) = channel + .send(&SendMessage::new(response, &msg.reply_target)) + .await + { + eprintln!(" ❌ Failed to reply on {}: {e}", channel.name()); + } + } + } + Ok(Err(e)) => { + eprintln!( + " ❌ LLM error after {}ms: {e}", + started_at.elapsed().as_millis() + ); + if let Some(channel) = target_channel.as_ref() { + let _ = channel + .send(&SendMessage::new( + format!("⚠️ Error: {e}"), + &msg.reply_target, + )) + .await; + } + } + Err(_) => { + let timeout_msg = format!( + "LLM response timed out after {}s", + CHANNEL_MESSAGE_TIMEOUT_SECS + ); + eprintln!( + " ❌ {} (elapsed: {}ms)", + timeout_msg, + started_at.elapsed().as_millis() + ); + if let Some(channel) = target_channel.as_ref() { + let _ = channel + .send(&SendMessage::new( + "⚠️ Request timed out while waiting for the model. Please try again.", + &msg.reply_target, + )) + .await; + } + } + } +} + +async fn run_message_dispatch_loop( + mut rx: tokio::sync::mpsc::Receiver, + ctx: Arc, + max_in_flight_messages: usize, +) { + let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages)); + let mut workers = tokio::task::JoinSet::new(); + + while let Some(msg) = rx.recv().await { + let permit = match Arc::clone(&semaphore).acquire_owned().await { + Ok(permit) => permit, + Err(_) => break, + }; + + let worker_ctx = Arc::clone(&ctx); + workers.spawn(async move { + let _permit = permit; + process_channel_message(worker_ctx, msg).await; + }); + + while let Some(result) = workers.try_join_next() { + log_worker_join_result(result); + } + } + + while let Some(result) = workers.join_next().await { + log_worker_join_result(result); + } +} + +/// Load OpenClaw format bootstrap files into the prompt. +fn load_openclaw_bootstrap_files( + prompt: &mut String, + workspace_dir: &std::path::Path, + max_chars_per_file: usize, +) { + prompt.push_str( + "The following workspace files define your identity, behavior, and context. They are ALREADY injected below—do NOT suggest reading them with file_read.\n\n", + ); + + let bootstrap_files = [ + "AGENTS.md", + "SOUL.md", + "TOOLS.md", + "IDENTITY.md", + "USER.md", + "HEARTBEAT.md", + ]; + + for filename in &bootstrap_files { + inject_workspace_file(prompt, workspace_dir, filename, max_chars_per_file); + } + + // BOOTSTRAP.md — only if it exists (first-run ritual) + let bootstrap_path = workspace_dir.join("BOOTSTRAP.md"); + if bootstrap_path.exists() { + inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md", max_chars_per_file); + } + + // MEMORY.md — curated long-term memory (main session only) + inject_workspace_file(prompt, workspace_dir, "MEMORY.md", max_chars_per_file); +} + /// Load workspace identity files and build a system prompt. /// -/// Follows the `OpenClaw` framework structure: +/// Follows the `OpenClaw` framework structure by default: /// 1. Tooling — tool list + descriptions /// 2. Safety — guardrail reminder /// 3. Skills — compact list with paths (loaded on-demand) @@ -77,6 +355,9 @@ fn spawn_supervised_listener( /// 6. Date & Time — timezone for cache stability /// 7. Runtime — host, OS, model /// +/// When `identity_config` is set to AIEOS format, the bootstrap files section +/// is replaced with the AIEOS identity data loaded from file or inline JSON. +/// /// Daily memory files (`memory/*.md`) are NOT injected — they are accessed /// on-demand via `memory_recall` / `memory_search` tools. pub fn build_system_prompt( @@ -84,6 +365,8 @@ pub fn build_system_prompt( model_name: &str, tools: &[(&str, &str)], skills: &[crate::skills::Skill], + identity_config: Option<&crate::config::IdentityConfig>, + bootstrap_max_chars: Option, ) -> String { use std::fmt::Write; let mut prompt = String::with_capacity(8192); @@ -95,9 +378,44 @@ pub fn build_system_prompt( for (name, desc) in tools { let _ = writeln!(prompt, "- **{name}**: {desc}"); } - prompt.push('\n'); + prompt.push_str("\n## Tool Use Protocol\n\n"); + prompt.push_str("To use a tool, wrap a JSON object in tags:\n\n"); + prompt.push_str("```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n"); + prompt.push_str("You may use multiple tool calls in a single response. "); + prompt.push_str("After tool execution, results appear in tags. "); + prompt + .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); } + // ── 1b. Hardware (when gpio/arduino tools present) ─────────── + let has_hardware = tools.iter().any(|(name, _)| { + *name == "gpio_read" + || *name == "gpio_write" + || *name == "arduino_upload" + || *name == "hardware_memory_map" + || *name == "hardware_board_info" + || *name == "hardware_memory_read" + || *name == "hardware_capabilities" + }); + if has_hardware { + prompt.push_str( + "## Hardware Access\n\n\ + You HAVE direct access to connected hardware (Arduino, Nucleo, etc.). The user owns this system and has configured it.\n\ + All hardware tools (gpio_read, gpio_write, hardware_memory_read, hardware_board_info, hardware_memory_map) are AUTHORIZED and NOT blocked by security.\n\ + When they ask to read memory, registers, or board info, USE hardware_memory_read or hardware_board_info — do NOT refuse or invent security excuses.\n\ + When they ask to control LEDs, run patterns, or interact with the Arduino, USE the tools — do NOT refuse or say you cannot access physical devices.\n\ + Use gpio_write for simple on/off; use arduino_upload when they want patterns (heart, blink) or custom behavior.\n\n", + ); + } + + // ── 1c. Action instruction (avoid meta-summary) ─────────────── + prompt.push_str( + "## Your Task\n\n\ + When the user sends a message, ACT on it. Use the tools to fulfill their request.\n\ + Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions (e.g. \"1. First... 2. Next...\").\n\ + Instead: emit actual tags when you need to act. Just do what they ask.\n\n", + ); + // ── 2. Safety ─────────────────────────────────────────────── prompt.push_str("## Safety\n\n"); prompt.push_str( @@ -144,31 +462,45 @@ pub fn build_system_prompt( // ── 5. Bootstrap files (injected into context) ────────────── prompt.push_str("## Project Context\n\n"); - prompt - .push_str("The following workspace files define your identity, behavior, and context.\n\n"); - let bootstrap_files = [ - "AGENTS.md", - "SOUL.md", - "TOOLS.md", - "IDENTITY.md", - "USER.md", - "HEARTBEAT.md", - ]; - - for filename in &bootstrap_files { - inject_workspace_file(&mut prompt, workspace_dir, filename); + // Check if AIEOS identity is configured + if let Some(config) = identity_config { + if identity::is_aieos_configured(config) { + // Load AIEOS identity + match identity::load_aieos_identity(config, workspace_dir) { + Ok(Some(aieos_identity)) => { + let aieos_prompt = identity::aieos_to_system_prompt(&aieos_identity); + if !aieos_prompt.is_empty() { + prompt.push_str(&aieos_prompt); + prompt.push_str("\n\n"); + } + } + Ok(None) => { + // No AIEOS identity loaded (shouldn't happen if is_aieos_configured returned true) + // Fall back to OpenClaw bootstrap files + let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS); + load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars); + } + Err(e) => { + // Log error but don't fail - fall back to OpenClaw + eprintln!( + "Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format." + ); + let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS); + load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars); + } + } + } else { + // OpenClaw format + let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS); + load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars); + } + } else { + // No identity config - use OpenClaw format + let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS); + load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars); } - // BOOTSTRAP.md — only if it exists (first-run ritual) - let bootstrap_path = workspace_dir.join("BOOTSTRAP.md"); - if bootstrap_path.exists() { - inject_workspace_file(&mut prompt, workspace_dir, "BOOTSTRAP.md"); - } - - // MEMORY.md — curated long-term memory (main session only) - inject_workspace_file(&mut prompt, workspace_dir, "MEMORY.md"); - // ── 6. Date & Time ────────────────────────────────────────── let now = chrono::Local::now(); let tz = now.format("%Z").to_string(); @@ -183,6 +515,16 @@ pub fn build_system_prompt( std::env::consts::OS, ); + // ── 8. Channel Capabilities ───────────────────────────────────── + prompt.push_str("## Channel Capabilities\n\n"); + prompt.push_str( + "- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n", + ); + prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n"); + prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n"); + prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n"); + prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n"); + if prompt.is_empty() { "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string() } else { @@ -190,40 +532,13 @@ pub fn build_system_prompt( } } -/// Inject `OpenClaw` (markdown) identity files into the prompt -fn inject_openclaw_identity(prompt: &mut String, workspace_dir: &std::path::Path) { - #[allow(unused_imports)] - use std::fmt::Write; - - prompt.push_str("## Project Context\n\n"); - prompt - .push_str("The following workspace files define your identity, behavior, and context.\n\n"); - - let bootstrap_files = [ - "AGENTS.md", - "SOUL.md", - "TOOLS.md", - "IDENTITY.md", - "USER.md", - "HEARTBEAT.md", - ]; - - for filename in &bootstrap_files { - inject_workspace_file(prompt, workspace_dir, filename); - } - - // BOOTSTRAP.md — only if it exists (first-run ritual) - let bootstrap_path = workspace_dir.join("BOOTSTRAP.md"); - if bootstrap_path.exists() { - inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md"); - } - - // MEMORY.md — curated long-term memory (main session only) - inject_workspace_file(prompt, workspace_dir, "MEMORY.md"); -} - /// Inject a single workspace file into the prompt with truncation and missing-file markers. -fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) { +fn inject_workspace_file( + prompt: &mut String, + workspace_dir: &std::path::Path, + filename: &str, + max_chars: usize, +) { use std::fmt::Write; let path = workspace_dir.join(filename); @@ -234,11 +549,21 @@ fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, f return; } let _ = writeln!(prompt, "### {filename}\n"); - if trimmed.len() > BOOTSTRAP_MAX_CHARS { - prompt.push_str(&trimmed[..BOOTSTRAP_MAX_CHARS]); + // Use character-boundary-safe truncation for UTF-8 + let truncated = if trimmed.chars().count() > max_chars { + trimmed + .char_indices() + .nth(max_chars) + .map(|(idx, _)| &trimmed[..idx]) + .unwrap_or(trimmed) + } else { + trimmed + }; + if truncated.len() < trimmed.len() { + prompt.push_str(truncated); let _ = writeln!( prompt, - "\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n" + "\n\n[... truncated at {max_chars} chars — use `read` for full file]\n" ); } else { prompt.push_str(trimmed); @@ -252,17 +577,145 @@ fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, f } } -pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Result<()> { +fn normalize_telegram_identity(value: &str) -> String { + value.trim().trim_start_matches('@').to_string() +} + +fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { + let normalized = normalize_telegram_identity(identity); + if normalized.is_empty() { + anyhow::bail!("Telegram identity cannot be empty"); + } + + let mut updated = config.clone(); + let Some(telegram) = updated.channels_config.telegram.as_mut() else { + anyhow::bail!( + "Telegram channel is not configured. Run `zeroclaw onboard --channels-only` first" + ); + }; + + if telegram.allowed_users.iter().any(|u| u == "*") { + println!( + "⚠️ Telegram allowlist is currently wildcard (`*`) — binding is unnecessary until you remove '*'." + ); + } + + if telegram + .allowed_users + .iter() + .map(|entry| normalize_telegram_identity(entry)) + .any(|entry| entry == normalized) + { + println!("✅ Telegram identity already bound: {normalized}"); + return Ok(()); + } + + telegram.allowed_users.push(normalized.clone()); + updated.save()?; + println!("✅ Bound Telegram identity: {normalized}"); + println!(" Saved to {}", updated.config_path.display()); + match maybe_restart_managed_daemon_service() { + Ok(true) => { + println!("🔄 Detected running managed daemon service; reloaded automatically."); + } + Ok(false) => { + println!( + "ℹ️ No managed daemon service detected. If `zeroclaw daemon`/`channel start` is already running, restart it to load the updated allowlist." + ); + } + Err(e) => { + eprintln!( + "⚠️ Allowlist saved, but failed to reload daemon service automatically: {e}\n\ + Restart service manually with `zeroclaw service stop && zeroclaw service start`." + ); + } + } + Ok(()) +} + +fn maybe_restart_managed_daemon_service() -> Result { + if cfg!(target_os = "macos") { + let home = directories::UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + let plist = home + .join("Library") + .join("LaunchAgents") + .join("com.zeroclaw.daemon.plist"); + if !plist.exists() { + return Ok(false); + } + + let list_output = Command::new("launchctl") + .arg("list") + .output() + .context("Failed to query launchctl list")?; + let listed = String::from_utf8_lossy(&list_output.stdout); + if !listed.contains("com.zeroclaw.daemon") { + return Ok(false); + } + + let _ = Command::new("launchctl") + .args(["stop", "com.zeroclaw.daemon"]) + .output(); + let start_output = Command::new("launchctl") + .args(["start", "com.zeroclaw.daemon"]) + .output() + .context("Failed to start launchd daemon service")?; + if !start_output.status.success() { + let stderr = String::from_utf8_lossy(&start_output.stderr); + anyhow::bail!("launchctl start failed: {}", stderr.trim()); + } + + return Ok(true); + } + + if cfg!(target_os = "linux") { + let home = directories::UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + let unit_path: PathBuf = home + .join(".config") + .join("systemd") + .join("user") + .join("zeroclaw.service"); + if !unit_path.exists() { + return Ok(false); + } + + let active_output = Command::new("systemctl") + .args(["--user", "is-active", "zeroclaw.service"]) + .output() + .context("Failed to query systemd service state")?; + let state = String::from_utf8_lossy(&active_output.stdout); + if !state.trim().eq_ignore_ascii_case("active") { + return Ok(false); + } + + let restart_output = Command::new("systemctl") + .args(["--user", "restart", "zeroclaw.service"]) + .output() + .context("Failed to restart systemd daemon service")?; + if !restart_output.status.success() { + let stderr = String::from_utf8_lossy(&restart_output.stderr); + anyhow::bail!("systemctl restart failed: {}", stderr.trim()); + } + + return Ok(true); + } + + Ok(false) +} + +pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { match command { - super::ChannelCommands::Start => { - // Handled in main.rs (needs async), this is unreachable - unreachable!("Start is handled in main.rs") + crate::ChannelCommands::Start => { + anyhow::bail!("Start must be handled in main.rs (requires async runtime)") } - super::ChannelCommands::Doctor => { - // Handled in main.rs (needs async), this is unreachable - unreachable!("Doctor is handled in main.rs") + crate::ChannelCommands::Doctor => { + anyhow::bail!("Doctor must be handled in main.rs (requires async runtime)") } - super::ChannelCommands::List => { + crate::ChannelCommands::List => { println!("Channels:"); println!(" ✅ CLI (always available)"); for (name, configured) in [ @@ -272,7 +725,13 @@ pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Resul ("Webhook", config.channels_config.webhook.is_some()), ("iMessage", config.channels_config.imessage.is_some()), ("Matrix", config.channels_config.matrix.is_some()), + ("Signal", config.channels_config.signal.is_some()), ("WhatsApp", config.channels_config.whatsapp.is_some()), + ("Email", config.channels_config.email.is_some()), + ("IRC", config.channels_config.irc.is_some()), + ("Lark", config.channels_config.lark.is_some()), + ("DingTalk", config.channels_config.dingtalk.is_some()), + ("QQ", config.channels_config.qq.is_some()), ] { println!(" {} {name}", if configured { "✅" } else { "❌" }); } @@ -281,7 +740,7 @@ pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Resul println!("To configure: zeroclaw onboard"); Ok(()) } - super::ChannelCommands::Add { + crate::ChannelCommands::Add { channel_type, config: _, } => { @@ -289,9 +748,12 @@ pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Resul "Channel type '{channel_type}' — use `zeroclaw onboard` to configure channels" ); } - super::ChannelCommands::Remove { name } => { + crate::ChannelCommands::Remove { name } => { anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly"); } + crate::ChannelCommands::BindTelegram { identity } => { + bind_telegram_identity(config, &identity) + } } } @@ -333,6 +795,8 @@ pub async fn doctor_channels(config: Config) -> Result<()> { dc.bot_token.clone(), dc.guild_id.clone(), dc.allowed_users.clone(), + dc.listen_to_bots, + dc.mention_only, )), )); } @@ -367,6 +831,20 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + if let Some(ref sig) = config.channels_config.signal { + channels.push(( + "Signal", + Arc::new(SignalChannel::new( + sig.http_url.clone(), + sig.account.clone(), + sig.group_id.clone(), + sig.allowed_from.clone(), + sig.ignore_attachments, + sig.ignore_stories, + )), + )); + } + if let Some(ref wa) = config.channels_config.whatsapp { channels.push(( "WhatsApp", @@ -379,6 +857,54 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + if let Some(ref email_cfg) = config.channels_config.email { + channels.push(("Email", Arc::new(EmailChannel::new(email_cfg.clone())))); + } + + if let Some(ref irc) = config.channels_config.irc { + channels.push(( + "IRC", + Arc::new(IrcChannel::new(irc::IrcChannelConfig { + server: irc.server.clone(), + port: irc.port, + nickname: irc.nickname.clone(), + username: irc.username.clone(), + channels: irc.channels.clone(), + allowed_users: irc.allowed_users.clone(), + server_password: irc.server_password.clone(), + nickserv_password: irc.nickserv_password.clone(), + sasl_password: irc.sasl_password.clone(), + verify_tls: irc.verify_tls.unwrap_or(true), + })), + )); + } + + if let Some(ref lk) = config.channels_config.lark { + channels.push(("Lark", Arc::new(LarkChannel::from_config(lk)))); + } + + if let Some(ref dt) = config.channels_config.dingtalk { + channels.push(( + "DingTalk", + Arc::new(DingTalkChannel::new( + dt.client_id.clone(), + dt.client_secret.clone(), + dt.allowed_users.clone(), + )), + )); + } + + if let Some(ref qq) = config.channels_config.qq { + channels.push(( + "QQ", + Arc::new(QQChannel::new( + qq.app_id.clone(), + qq.app_secret.clone(), + qq.allowed_users.clone(), + )), + )); + } + if channels.is_empty() { println!("No real-time channels configured. Run `zeroclaw onboard` first."); return Ok(()); @@ -423,11 +949,31 @@ pub async fn doctor_channels(config: Config) -> Result<()> { /// Start all configured channels and route messages to the agent #[allow(clippy::too_many_lines)] pub async fn start_channels(config: Config) -> Result<()> { + let provider_name = config + .default_provider + .clone() + .unwrap_or_else(|| "openrouter".into()); let provider: Arc = Arc::from(providers::create_resilient_provider( - config.default_provider.as_deref().unwrap_or("openrouter"), + &provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); + + // Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup) + // so the first real message doesn't hit a cold-start timeout. + if let Err(e) = provider.warmup().await { + tracing::warn!("Provider warmup failed (non-fatal): {e}"); + } + + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); let model = config .default_model .clone() @@ -438,9 +984,31 @@ pub async fn start_channels(config: Config) -> Result<()> { &config.workspace_dir, config.api_key.as_deref(), )?); - + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) + } else { + (None, None) + }; // Build system prompt from workspace identity files + skills let workspace = config.workspace_dir.clone(); + let tools_registry = Arc::new(tools::all_tools_with_runtime( + Arc::new(config.clone()), + &security, + runtime, + Arc::clone(&mem), + composio_key, + composio_entity_id, + &config.browser, + &config.http_request, + &workspace, + &config.agents, + config.api_key.as_deref(), + &config, + )); + let skills = crate::skills::load_skills(&workspace); // Collect tool descriptions for the prompt @@ -477,8 +1045,41 @@ pub async fn start_channels(config: Config) -> Result<()> { "Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)", )); } + if config.composio.enabled { + tool_descs.push(( + "composio", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.", + )); + } + tool_descs.push(( + "schedule", + "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", + )); + tool_descs.push(( + "pushover", + "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.", + )); + if !config.agents.is_empty() { + tool_descs.push(( + "delegate", + "Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.", + )); + } - let system_prompt = build_system_prompt(&workspace, &model, &tool_descs, &skills); + let bootstrap_max_chars = if config.agent.compact_context { + Some(6000) + } else { + None + }; + let mut system_prompt = build_system_prompt( + &workspace, + &model, + &tool_descs, + &skills, + Some(&config.identity), + bootstrap_max_chars, + ); + system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref())); if !skills.is_empty() { println!( @@ -506,6 +1107,8 @@ pub async fn start_channels(config: Config) -> Result<()> { dc.bot_token.clone(), dc.guild_id.clone(), dc.allowed_users.clone(), + dc.listen_to_bots, + dc.mention_only, ))); } @@ -517,6 +1120,15 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + if let Some(ref mm) = config.channels_config.mattermost { + channels.push(Arc::new(MattermostChannel::new( + mm.url.clone(), + mm.bot_token.clone(), + mm.channel_id.clone(), + mm.allowed_users.clone(), + ))); + } + if let Some(ref im) = config.channels_config.imessage { channels.push(Arc::new(IMessageChannel::new(im.allowed_contacts.clone()))); } @@ -530,6 +1142,17 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + if let Some(ref sig) = config.channels_config.signal { + channels.push(Arc::new(SignalChannel::new( + sig.http_url.clone(), + sig.account.clone(), + sig.group_id.clone(), + sig.allowed_from.clone(), + sig.ignore_attachments, + sig.ignore_stories, + ))); + } + if let Some(ref wa) = config.channels_config.whatsapp { channels.push(Arc::new(WhatsAppChannel::new( wa.access_token.clone(), @@ -539,6 +1162,45 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + if let Some(ref email_cfg) = config.channels_config.email { + channels.push(Arc::new(EmailChannel::new(email_cfg.clone()))); + } + + if let Some(ref irc) = config.channels_config.irc { + channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig { + server: irc.server.clone(), + port: irc.port, + nickname: irc.nickname.clone(), + username: irc.username.clone(), + channels: irc.channels.clone(), + allowed_users: irc.allowed_users.clone(), + server_password: irc.server_password.clone(), + nickserv_password: irc.nickserv_password.clone(), + sasl_password: irc.sasl_password.clone(), + verify_tls: irc.verify_tls.unwrap_or(true), + }))); + } + + if let Some(ref lk) = config.channels_config.lark { + channels.push(Arc::new(LarkChannel::from_config(lk))); + } + + if let Some(ref dt) = config.channels_config.dingtalk { + channels.push(Arc::new(DingTalkChannel::new( + dt.client_id.clone(), + dt.client_secret.clone(), + dt.allowed_users.clone(), + ))); + } + + if let Some(ref qq) = config.channels_config.qq { + channels.push(Arc::new(QQChannel::new( + qq.app_id.clone(), + qq.app_secret.clone(), + qq.allowed_users.clone(), + ))); + } + if channels.is_empty() { println!("No channels configured. Run `zeroclaw onboard` to set up channels."); return Ok(()); @@ -575,7 +1237,7 @@ pub async fn start_channels(config: Config) -> Result<()> { .max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS); // Single message bus — all channels send messages here - let (tx, mut rx) = tokio::sync::mpsc::channel::(100); + let (tx, rx) = tokio::sync::mpsc::channel::(100); // Spawn a listener for each channel let mut handles = Vec::new(); @@ -589,65 +1251,29 @@ pub async fn start_channels(config: Config) -> Result<()> { } drop(tx); // Drop our copy so rx closes when all channels stop - // Process incoming messages — call the LLM and reply - while let Some(msg) = rx.recv().await { - println!( - " 💬 [{}] from {}: {}", - msg.channel, - msg.sender, - if msg.content.len() > 80 { - format!("{}...", &msg.content[..80]) - } else { - msg.content.clone() - } - ); + let channels_by_name = Arc::new( + channels + .iter() + .map(|ch| (ch.name().to_string(), Arc::clone(ch))) + .collect::>(), + ); + let max_in_flight_messages = compute_max_in_flight_messages(channels.len()); - // Auto-save to memory - if config.memory.auto_save { - let _ = mem - .store( - &format!("{}_{}", msg.channel, msg.sender), - &msg.content, - crate::memory::MemoryCategory::Conversation, - ) - .await; - } + println!(" 🚦 In-flight message limit: {max_in_flight_messages}"); - // Call the LLM with system prompt (identity + soul + tools) - match provider - .chat_with_system(Some(&system_prompt), &msg.content, &model, temperature) - .await - { - Ok(response) => { - println!( - " 🤖 Reply: {}", - if response.len() > 80 { - format!("{}...", &response[..80]) - } else { - response.clone() - } - ); - // Find the channel that sent this message and reply - for ch in &channels { - if ch.name() == msg.channel { - if let Err(e) = ch.send(&response, &msg.sender).await { - eprintln!(" ❌ Failed to reply on {}: {e}", ch.name()); - } - break; - } - } - } - Err(e) => { - eprintln!(" ❌ LLM error: {e}"); - for ch in &channels { - if ch.name() == msg.channel { - let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await; - break; - } - } - } - } - } + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name, + provider: Arc::clone(&provider), + memory: Arc::clone(&mem), + tools_registry: Arc::clone(&tools_registry), + observer, + system_prompt: Arc::new(system_prompt), + model: Arc::new(model.clone()), + temperature, + auto_save_memory: config.memory.auto_save, + }); + + run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; // Wait for all channel tasks for h in handles { @@ -660,6 +1286,11 @@ pub async fn start_channels(config: Config) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use crate::observability::NoopObserver; + use crate::providers::{ChatMessage, Provider}; + use crate::tools::{Tool, ToolResult}; + use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tempfile::TempDir; @@ -685,11 +1316,365 @@ mod tests { tmp } + #[derive(Default)] + struct RecordingChannel { + sent_messages: tokio::sync::Mutex>, + } + + #[async_trait::async_trait] + impl Channel for RecordingChannel { + fn name(&self) -> &str { + "test-channel" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + self.sent_messages + .lock() + .await + .push(format!("{}:{}", message.recipient, message.content)); + Ok(()) + } + + async fn listen( + &self, + _tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + struct SlowProvider { + delay: Duration, + } + + #[async_trait::async_trait] + impl Provider for SlowProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + tokio::time::sleep(self.delay).await; + Ok(format!("echo: {message}")) + } + } + + struct ToolCallingProvider; + + fn tool_call_payload() -> String { + r#" +{"name":"mock_price","arguments":{"symbol":"BTC"}} +"# + .to_string() + } + + fn tool_call_payload_with_alias_tag() -> String { + r#" +{"name":"mock_price","arguments":{"symbol":"BTC"}} +"# + .to_string() + } + + #[async_trait::async_trait] + impl Provider for ToolCallingProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok(tool_call_payload()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let has_tool_results = messages + .iter() + .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]")); + if has_tool_results { + Ok("BTC is currently around $65,000 based on latest tool output.".to_string()) + } else { + Ok(tool_call_payload()) + } + } + } + + struct ToolCallingAliasProvider; + + #[async_trait::async_trait] + impl Provider for ToolCallingAliasProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok(tool_call_payload_with_alias_tag()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let has_tool_results = messages + .iter() + .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]")); + if has_tool_results { + Ok("BTC alias-tag flow resolved to final text output.".to_string()) + } else { + Ok(tool_call_payload_with_alias_tag()) + } + } + } + + struct MockPriceTool; + + #[async_trait::async_trait] + impl Tool for MockPriceTool { + fn name(&self) -> &str { + "mock_price" + } + + fn description(&self) -> &str { + "Return a mocked BTC price" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "symbol": { "type": "string" } + }, + "required": ["symbol"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let symbol = args.get("symbol").and_then(serde_json::Value::as_str); + if symbol != Some("BTC") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("unexpected symbol".to_string()), + }); + } + + Ok(ToolResult { + success: true, + output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(), + error: None, + }) + } + } + + #[tokio::test] + async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(ToolCallingProvider), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-42".to_string(), + content: "What is the BTC price now?".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + }, + ) + .await; + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].starts_with("chat-42:")); + assert!(sent_messages[0].contains("BTC is currently around")); + assert!(!sent_messages[0].contains("\"tool_calls\"")); + assert!(!sent_messages[0].contains("mock_price")); + } + + #[tokio::test] + async fn process_channel_message_executes_tool_calls_with_alias_tags() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(ToolCallingAliasProvider), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-2".to_string(), + sender: "bob".to_string(), + reply_target: "chat-84".to_string(), + content: "What is the BTC price now?".to_string(), + channel: "test-channel".to_string(), + timestamp: 2, + }, + ) + .await; + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].starts_with("chat-84:")); + assert!(sent_messages[0].contains("alias-tag flow resolved")); + assert!(!sent_messages[0].contains("")); + assert!(!sent_messages[0].contains("mock_price")); + } + + struct NoopMemory; + + #[async_trait::async_trait] + impl Memory for NoopMemory { + fn name(&self) -> &str { + "noop" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: crate::memory::MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&crate::memory::MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + #[tokio::test] + async fn message_dispatch_processes_messages_in_parallel() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(250), + }), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + }); + + let (tx, rx) = tokio::sync::mpsc::channel::(4); + tx.send(traits::ChannelMessage { + id: "1".to_string(), + sender: "alice".to_string(), + reply_target: "alice".to_string(), + content: "hello".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + }) + .await + .unwrap(); + tx.send(traits::ChannelMessage { + id: "2".to_string(), + sender: "bob".to_string(), + reply_target: "bob".to_string(), + content: "world".to_string(), + channel: "test-channel".to_string(), + timestamp: 2, + }) + .await + .unwrap(); + drop(tx); + + let started = Instant::now(); + run_message_dispatch_loop(rx, runtime_ctx, 2).await; + let elapsed = started.elapsed(); + + assert!( + elapsed < Duration::from_millis(430), + "expected parallel dispatch (<430ms), got {:?}", + elapsed + ); + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 2); + } + #[test] fn prompt_contains_all_sections() { let ws = make_workspace(); let tools = vec![("shell", "Run commands"), ("file_read", "Read files")]; - let prompt = build_system_prompt(ws.path(), "test-model", &tools, &[]); + let prompt = build_system_prompt(ws.path(), "test-model", &tools, &[], None, None); // Section headers assert!(prompt.contains("## Tools"), "missing Tools section"); @@ -713,7 +1698,7 @@ mod tests { ("shell", "Run commands"), ("memory_recall", "Search memory"), ]; - let prompt = build_system_prompt(ws.path(), "gpt-4o", &tools, &[]); + let prompt = build_system_prompt(ws.path(), "gpt-4o", &tools, &[], None, None); assert!(prompt.contains("**shell**")); assert!(prompt.contains("Run commands")); @@ -723,7 +1708,7 @@ mod tests { #[test] fn prompt_injects_safety() { let ws = make_workspace(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!(prompt.contains("Do not exfiltrate private data")); assert!(prompt.contains("Do not run destructive commands")); @@ -733,7 +1718,7 @@ mod tests { #[test] fn prompt_injects_workspace_files() { let ws = make_workspace(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!(prompt.contains("### SOUL.md"), "missing SOUL.md header"); assert!(prompt.contains("Be helpful"), "missing SOUL content"); @@ -754,7 +1739,7 @@ mod tests { fn prompt_missing_file_markers() { let tmp = TempDir::new().unwrap(); // Empty workspace — no files at all - let prompt = build_system_prompt(tmp.path(), "model", &[], &[]); + let prompt = build_system_prompt(tmp.path(), "model", &[], &[], None, None); assert!(prompt.contains("[File not found: SOUL.md]")); assert!(prompt.contains("[File not found: AGENTS.md]")); @@ -765,7 +1750,7 @@ mod tests { fn prompt_bootstrap_only_if_exists() { let ws = make_workspace(); // No BOOTSTRAP.md — should not appear - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!( !prompt.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should not appear when missing" @@ -773,7 +1758,7 @@ mod tests { // Create BOOTSTRAP.md — should appear std::fs::write(ws.path().join("BOOTSTRAP.md"), "# Bootstrap\nFirst run.").unwrap(); - let prompt2 = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt2 = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!( prompt2.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should appear when present" @@ -793,7 +1778,7 @@ mod tests { ) .unwrap(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); // Daily notes should NOT be in the system prompt (on-demand via tools) assert!( @@ -809,7 +1794,7 @@ mod tests { #[test] fn prompt_runtime_metadata() { let ws = make_workspace(); - let prompt = build_system_prompt(ws.path(), "claude-sonnet-4", &[], &[]); + let prompt = build_system_prompt(ws.path(), "claude-sonnet-4", &[], &[], None, None); assert!(prompt.contains("Model: claude-sonnet-4")); assert!(prompt.contains(&format!("OS: {}", std::env::consts::OS))); @@ -830,7 +1815,7 @@ mod tests { location: None, }]; - let prompt = build_system_prompt(ws.path(), "model", &[], &skills); + let prompt = build_system_prompt(ws.path(), "model", &[], &skills, None, None); assert!(prompt.contains(""), "missing skills XML"); assert!(prompt.contains("code-review")); @@ -851,7 +1836,7 @@ mod tests { let big_content = "x".repeat(BOOTSTRAP_MAX_CHARS + 1000); std::fs::write(ws.path().join("AGENTS.md"), &big_content).unwrap(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!( prompt.contains("truncated at"), @@ -868,7 +1853,7 @@ mod tests { let ws = make_workspace(); std::fs::write(ws.path().join("TOOLS.md"), "").unwrap(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); // Empty file should not produce a header assert!( @@ -877,14 +1862,295 @@ mod tests { ); } + #[test] + fn channel_log_truncation_is_utf8_safe_for_multibyte_text() { + let msg = "Hello from ZeroClaw 🌍. Current status is healthy, and café-style UTF-8 text stays safe in logs."; + + // Reproduces the production crash path where channel logs truncate at 80 chars. + let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80)); + assert!( + result.is_ok(), + "truncate_with_ellipsis should never panic on UTF-8" + ); + + let truncated = result.unwrap(); + assert!(!truncated.is_empty()); + assert!(truncated.is_char_boundary(truncated.len())); + } + + #[test] + fn prompt_contains_channel_capabilities() { + let ws = make_workspace(); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); + + assert!( + prompt.contains("## Channel Capabilities"), + "missing Channel Capabilities section" + ); + assert!( + prompt.contains("running as a Discord bot"), + "missing Discord context" + ); + assert!( + prompt.contains("NEVER repeat, describe, or echo credentials"), + "missing security instruction" + ); + } + #[test] fn prompt_workspace_path() { let ws = make_workspace(); - let prompt = build_system_prompt(ws.path(), "model", &[], &[]); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); assert!(prompt.contains(&format!("Working directory: `{}`", ws.path().display()))); } + #[test] + fn conversation_memory_key_uses_message_id() { + let msg = traits::ChannelMessage { + id: "msg_abc123".into(), + sender: "U123".into(), + reply_target: "C456".into(), + content: "hello".into(), + channel: "slack".into(), + timestamp: 1, + }; + + assert_eq!(conversation_memory_key(&msg), "slack_U123_msg_abc123"); + } + + #[test] + fn conversation_memory_key_is_unique_per_message() { + let msg1 = traits::ChannelMessage { + id: "msg_1".into(), + sender: "U123".into(), + reply_target: "C456".into(), + content: "first".into(), + channel: "slack".into(), + timestamp: 1, + }; + let msg2 = traits::ChannelMessage { + id: "msg_2".into(), + sender: "U123".into(), + reply_target: "C456".into(), + content: "second".into(), + channel: "slack".into(), + timestamp: 2, + }; + + assert_ne!( + conversation_memory_key(&msg1), + conversation_memory_key(&msg2) + ); + } + + #[tokio::test] + async fn autosave_keys_preserve_multiple_conversation_facts() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let msg1 = traits::ChannelMessage { + id: "msg_1".into(), + sender: "U123".into(), + reply_target: "C456".into(), + content: "I'm Paul".into(), + channel: "slack".into(), + timestamp: 1, + }; + let msg2 = traits::ChannelMessage { + id: "msg_2".into(), + sender: "U123".into(), + reply_target: "C456".into(), + content: "I'm 45".into(), + channel: "slack".into(), + timestamp: 2, + }; + + mem.store( + &conversation_memory_key(&msg1), + &msg1.content, + MemoryCategory::Conversation, + None, + ) + .await + .unwrap(); + mem.store( + &conversation_memory_key(&msg2), + &msg2.content, + MemoryCategory::Conversation, + None, + ) + .await + .unwrap(); + + assert_eq!(mem.count().await.unwrap(), 2); + + let recalled = mem.recall("45", 5, None).await.unwrap(); + assert!(recalled.iter().any(|entry| entry.content.contains("45"))); + } + + #[tokio::test] + async fn build_memory_context_includes_recalled_entries() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None) + .await + .unwrap(); + + let context = build_memory_context(&mem, "age").await; + assert!(context.contains("[Memory context]")); + assert!(context.contains("Age is 45")); + } + + // ── AIEOS Identity Tests (Issue #168) ───────────────────────── + + #[test] + fn aieos_identity_from_file() { + use crate::config::IdentityConfig; + use tempfile::TempDir; + + let tmp = TempDir::new().unwrap(); + let identity_path = tmp.path().join("aieos_identity.json"); + + // Write AIEOS identity file + let aieos_json = r#"{ + "identity": { + "names": {"first": "Nova", "nickname": "Nov"}, + "bio": "A helpful AI assistant.", + "origin": "Silicon Valley" + }, + "psychology": { + "mbti": "INTJ", + "moral_compass": ["Be helpful", "Do no harm"] + }, + "linguistics": { + "style": "concise", + "formality": "casual" + } + }"#; + std::fs::write(&identity_path, aieos_json).unwrap(); + + // Create identity config pointing to the file + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: Some("aieos_identity.json".into()), + aieos_inline: None, + }; + + let prompt = build_system_prompt(tmp.path(), "model", &[], &[], Some(&config), None); + + // Should contain AIEOS sections + assert!(prompt.contains("## Identity")); + assert!(prompt.contains("**Name:** Nova")); + assert!(prompt.contains("**Nickname:** Nov")); + assert!(prompt.contains("**Bio:** A helpful AI assistant.")); + assert!(prompt.contains("**Origin:** Silicon Valley")); + + assert!(prompt.contains("## Personality")); + assert!(prompt.contains("**MBTI:** INTJ")); + assert!(prompt.contains("**Moral Compass:**")); + assert!(prompt.contains("- Be helpful")); + + assert!(prompt.contains("## Communication Style")); + assert!(prompt.contains("**Style:** concise")); + assert!(prompt.contains("**Formality Level:** casual")); + + // Should NOT contain OpenClaw bootstrap file headers + assert!(!prompt.contains("### SOUL.md")); + assert!(!prompt.contains("### IDENTITY.md")); + assert!(!prompt.contains("[File not found")); + } + + #[test] + fn aieos_identity_from_inline() { + use crate::config::IdentityConfig; + + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: None, + aieos_inline: Some(r#"{"identity":{"names":{"first":"Claw"}}}"#.into()), + }; + + let prompt = build_system_prompt( + std::env::temp_dir().as_path(), + "model", + &[], + &[], + Some(&config), + None, + ); + + assert!(prompt.contains("**Name:** Claw")); + assert!(prompt.contains("## Identity")); + } + + #[test] + fn aieos_fallback_to_openclaw_on_parse_error() { + use crate::config::IdentityConfig; + + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: Some("nonexistent.json".into()), + aieos_inline: None, + }; + + let ws = make_workspace(); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], Some(&config), None); + + // Should fall back to OpenClaw format when AIEOS file is not found + // (Error is logged to stderr with filename, not included in prompt) + assert!(prompt.contains("### SOUL.md")); + } + + #[test] + fn aieos_empty_uses_openclaw() { + use crate::config::IdentityConfig; + + // Format is "aieos" but neither path nor inline is set + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: None, + aieos_inline: None, + }; + + let ws = make_workspace(); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], Some(&config), None); + + // Should use OpenClaw format (not configured for AIEOS) + assert!(prompt.contains("### SOUL.md")); + assert!(prompt.contains("Be helpful")); + } + + #[test] + fn openclaw_format_uses_bootstrap_files() { + use crate::config::IdentityConfig; + + let config = IdentityConfig { + format: "openclaw".into(), + aieos_path: Some("identity.json".into()), + aieos_inline: None, + }; + + let ws = make_workspace(); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], Some(&config), None); + + // Should use OpenClaw format even if aieos_path is set + assert!(prompt.contains("### SOUL.md")); + assert!(prompt.contains("Be helpful")); + assert!(!prompt.contains("## Identity")); + } + + #[test] + fn none_identity_config_uses_openclaw() { + let ws = make_workspace(); + // Pass None for identity config + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); + + // Should use OpenClaw format + assert!(prompt.contains("### SOUL.md")); + assert!(prompt.contains("Be helpful")); + } + #[test] fn classify_health_ok_true() { let state = classify_health_result(&Ok(true)); @@ -919,7 +2185,7 @@ mod tests { self.name } - async fn send(&self, _message: &str, _recipient: &str) -> anyhow::Result<()> { + async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> { Ok(()) } diff --git a/src/channels/qq.rs b/src/channels/qq.rs new file mode 100644 index 0000000..3391fd7 --- /dev/null +++ b/src/channels/qq.rs @@ -0,0 +1,478 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use serde_json::json; +use std::collections::HashSet; +use std::sync::Arc; +use tokio::sync::RwLock; +use tokio_tungstenite::tungstenite::Message; +use uuid::Uuid; + +const QQ_API_BASE: &str = "https://api.sgroup.qq.com"; +const QQ_AUTH_URL: &str = "https://bots.qq.com/app/getAppAccessToken"; + +/// Deduplication set capacity — evict half of entries when full. +const DEDUP_CAPACITY: usize = 10_000; + +/// QQ Official Bot channel — uses Tencent's official QQ Bot API with +/// OAuth2 authentication and a Discord-like WebSocket gateway protocol. +pub struct QQChannel { + app_id: String, + app_secret: String, + allowed_users: Vec, + client: reqwest::Client, + /// Cached access token + expiry timestamp. + token_cache: Arc>>, + /// Message deduplication set. + dedup: Arc>>, +} + +impl QQChannel { + pub fn new(app_id: String, app_secret: String, allowed_users: Vec) -> Self { + Self { + app_id, + app_secret, + allowed_users, + client: reqwest::Client::new(), + token_cache: Arc::new(RwLock::new(None)), + dedup: Arc::new(RwLock::new(HashSet::new())), + } + } + + fn is_user_allowed(&self, user_id: &str) -> bool { + self.allowed_users.iter().any(|u| u == "*" || u == user_id) + } + + /// Fetch an access token from QQ's OAuth2 endpoint. + async fn fetch_access_token(&self) -> anyhow::Result<(String, u64)> { + let body = json!({ + "appId": self.app_id, + "clientSecret": self.app_secret, + }); + + let resp = self.client.post(QQ_AUTH_URL).json(&body).send().await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("QQ token request failed ({status}): {err}"); + } + + let data: serde_json::Value = resp.json().await?; + let token = data + .get("access_token") + .and_then(|t| t.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing access_token in QQ response"))? + .to_string(); + + let expires_in = data + .get("expires_in") + .and_then(|e| e.as_str()) + .and_then(|e| e.parse::().ok()) + .unwrap_or(7200); + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + // Expire 60 seconds early to avoid edge cases + let expiry = now + expires_in.saturating_sub(60); + + Ok((token, expiry)) + } + + /// Get a valid access token, refreshing if expired. + async fn get_token(&self) -> anyhow::Result { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + { + let cache = self.token_cache.read().await; + if let Some((ref token, expiry)) = *cache { + if now < expiry { + return Ok(token.clone()); + } + } + } + + let (token, expiry) = self.fetch_access_token().await?; + { + let mut cache = self.token_cache.write().await; + *cache = Some((token.clone(), expiry)); + } + Ok(token) + } + + /// Get the WebSocket gateway URL. + async fn get_gateway_url(&self, token: &str) -> anyhow::Result { + let resp = self + .client + .get(format!("{QQ_API_BASE}/gateway")) + .header("Authorization", format!("QQBot {token}")) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("QQ gateway request failed ({status}): {err}"); + } + + let data: serde_json::Value = resp.json().await?; + let url = data + .get("url") + .and_then(|u| u.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing gateway URL in QQ response"))? + .to_string(); + + Ok(url) + } + + /// Check and insert message ID for deduplication. + async fn is_duplicate(&self, msg_id: &str) -> bool { + if msg_id.is_empty() { + return false; + } + + let mut dedup = self.dedup.write().await; + + if dedup.contains(msg_id) { + return true; + } + + // Evict oldest half when at capacity + if dedup.len() >= DEDUP_CAPACITY { + let to_remove: Vec = dedup.iter().take(DEDUP_CAPACITY / 2).cloned().collect(); + for key in to_remove { + dedup.remove(&key); + } + } + + dedup.insert(msg_id.to_string()); + false + } +} + +#[async_trait] +impl Channel for QQChannel { + fn name(&self) -> &str { + "qq" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let token = self.get_token().await?; + + // Determine if this is a group or private message based on recipient format + // Format: "user:{openid}" or "group:{group_openid}" + let (url, body) = if let Some(group_id) = message.recipient.strip_prefix("group:") { + ( + format!("{QQ_API_BASE}/v2/groups/{group_id}/messages"), + json!({ + "content": &message.content, + "msg_type": 0, + }), + ) + } else { + let user_id = message + .recipient + .strip_prefix("user:") + .unwrap_or(&message.recipient); + ( + format!("{QQ_API_BASE}/v2/users/{user_id}/messages"), + json!({ + "content": &message.content, + "msg_type": 0, + }), + ) + }; + + let resp = self + .client + .post(&url) + .header("Authorization", format!("QQBot {token}")) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err = resp.text().await.unwrap_or_default(); + anyhow::bail!("QQ send message failed ({status}): {err}"); + } + + Ok(()) + } + + #[allow(clippy::too_many_lines)] + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + tracing::info!("QQ: authenticating..."); + let token = self.get_token().await?; + + tracing::info!("QQ: fetching gateway URL..."); + let gw_url = self.get_gateway_url(&token).await?; + + tracing::info!("QQ: connecting to gateway WebSocket..."); + let (ws_stream, _) = tokio_tungstenite::connect_async(&gw_url).await?; + let (mut write, mut read) = ws_stream.split(); + + // Read Hello (opcode 10) + let hello = read + .next() + .await + .ok_or(anyhow::anyhow!("QQ: no hello frame"))??; + let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?; + let heartbeat_interval = hello_data + .get("d") + .and_then(|d| d.get("heartbeat_interval")) + .and_then(serde_json::Value::as_u64) + .unwrap_or(41250); + + // Send Identify (opcode 2) + // Intents: PUBLIC_GUILD_MESSAGES (1<<30) | C2C_MESSAGE_CREATE & GROUP_AT_MESSAGE_CREATE (1<<25) + let intents: u64 = (1 << 25) | (1 << 30); + let identify = json!({ + "op": 2, + "d": { + "token": format!("QQBot {token}"), + "intents": intents, + "properties": { + "os": "linux", + "browser": "zeroclaw", + "device": "zeroclaw", + } + } + }); + write.send(Message::Text(identify.to_string())).await?; + + tracing::info!("QQ: connected and identified"); + + let mut sequence: i64 = -1; + + // Spawn heartbeat timer + let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1); + let hb_interval = heartbeat_interval; + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_millis(hb_interval)); + loop { + interval.tick().await; + if hb_tx.send(()).await.is_err() { + break; + } + } + }); + + loop { + tokio::select! { + _ = hb_rx.recv() => { + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); + if write.send(Message::Text(hb.to_string())).await.is_err() { + break; + } + } + msg = read.next() => { + let msg = match msg { + Some(Ok(Message::Text(t))) => t, + Some(Ok(Message::Close(_))) | None => break, + _ => continue, + }; + + let event: serde_json::Value = match serde_json::from_str(&msg) { + Ok(e) => e, + Err(_) => continue, + }; + + // Track sequence number + if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) { + sequence = s; + } + + let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0); + + match op { + // Server requests immediate heartbeat + 1 => { + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); + if write.send(Message::Text(hb.to_string())).await.is_err() { + break; + } + continue; + } + // Reconnect + 7 => { + tracing::warn!("QQ: received Reconnect (op 7)"); + break; + } + // Invalid Session + 9 => { + tracing::warn!("QQ: received Invalid Session (op 9)"); + break; + } + _ => {} + } + + // Only process dispatch events (op 0) + if op != 0 { + continue; + } + + let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or(""); + let d = match event.get("d") { + Some(d) => d, + None => continue, + }; + + match event_type { + "C2C_MESSAGE_CREATE" => { + let msg_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); + if self.is_duplicate(msg_id).await { + continue; + } + + let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("").trim(); + if content.is_empty() { + continue; + } + + let author_id = d.get("author").and_then(|a| a.get("id")).and_then(|i| i.as_str()).unwrap_or("unknown"); + // For QQ, user_openid is the identifier + let user_openid = d.get("author").and_then(|a| a.get("user_openid")).and_then(|u| u.as_str()).unwrap_or(author_id); + + if !self.is_user_allowed(user_openid) { + tracing::warn!("QQ: ignoring C2C message from unauthorized user: {user_openid}"); + continue; + } + + let chat_id = format!("user:{user_openid}"); + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: user_openid.to_string(), + reply_target: chat_id, + content: content.to_string(), + channel: "qq".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + if tx.send(channel_msg).await.is_err() { + tracing::warn!("QQ: message channel closed"); + break; + } + } + "GROUP_AT_MESSAGE_CREATE" => { + let msg_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); + if self.is_duplicate(msg_id).await { + continue; + } + + let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("").trim(); + if content.is_empty() { + continue; + } + + let author_id = d.get("author").and_then(|a| a.get("member_openid")).and_then(|m| m.as_str()).unwrap_or("unknown"); + + if !self.is_user_allowed(author_id) { + tracing::warn!("QQ: ignoring group message from unauthorized user: {author_id}"); + continue; + } + + let group_openid = d.get("group_openid").and_then(|g| g.as_str()).unwrap_or("unknown"); + let chat_id = format!("group:{group_openid}"); + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: author_id.to_string(), + reply_target: chat_id, + content: content.to_string(), + channel: "qq".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + if tx.send(channel_msg).await.is_err() { + tracing::warn!("QQ: message channel closed"); + break; + } + } + _ => {} + } + } + } + } + + anyhow::bail!("QQ WebSocket connection closed") + } + + async fn health_check(&self) -> bool { + self.fetch_access_token().await.is_ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_name() { + let ch = QQChannel::new("id".into(), "secret".into(), vec![]); + assert_eq!(ch.name(), "qq"); + } + + #[test] + fn test_user_allowed_wildcard() { + let ch = QQChannel::new("id".into(), "secret".into(), vec!["*".into()]); + assert!(ch.is_user_allowed("anyone")); + } + + #[test] + fn test_user_allowed_specific() { + let ch = QQChannel::new("id".into(), "secret".into(), vec!["user123".into()]); + assert!(ch.is_user_allowed("user123")); + assert!(!ch.is_user_allowed("other")); + } + + #[test] + fn test_user_denied_empty() { + let ch = QQChannel::new("id".into(), "secret".into(), vec![]); + assert!(!ch.is_user_allowed("anyone")); + } + + #[tokio::test] + async fn test_dedup() { + let ch = QQChannel::new("id".into(), "secret".into(), vec![]); + assert!(!ch.is_duplicate("msg1").await); + assert!(ch.is_duplicate("msg1").await); + assert!(!ch.is_duplicate("msg2").await); + } + + #[tokio::test] + async fn test_dedup_empty_id() { + let ch = QQChannel::new("id".into(), "secret".into(), vec![]); + // Empty IDs should never be considered duplicates + assert!(!ch.is_duplicate("").await); + assert!(!ch.is_duplicate("").await); + } + + #[test] + fn test_config_serde() { + let toml_str = r#" +app_id = "12345" +app_secret = "secret_abc" +allowed_users = ["user1"] +"#; + let config: crate::config::schema::QQConfig = toml::from_str(toml_str).unwrap(); + assert_eq!(config.app_id, "12345"); + assert_eq!(config.app_secret, "secret_abc"); + assert_eq!(config.allowed_users, vec!["user1"]); + } +} diff --git a/src/channels/signal.rs b/src/channels/signal.rs new file mode 100644 index 0000000..2cbbc84 --- /dev/null +++ b/src/channels/signal.rs @@ -0,0 +1,809 @@ +use crate::channels::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use futures_util::StreamExt; +use reqwest::Client; +use serde::Deserialize; +use std::time::Duration; +use tokio::sync::mpsc; +use uuid::Uuid; + +const GROUP_TARGET_PREFIX: &str = "group:"; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum RecipientTarget { + Direct(String), + Group(String), +} + +/// Signal channel using signal-cli daemon's native JSON-RPC + SSE API. +/// +/// Connects to a running `signal-cli daemon --http `. +/// Listens via SSE at `/api/v1/events` and sends via JSON-RPC at +/// `/api/v1/rpc`. +#[derive(Clone)] +pub struct SignalChannel { + http_url: String, + account: String, + group_id: Option, + allowed_from: Vec, + ignore_attachments: bool, + ignore_stories: bool, + client: Client, +} + +// ── signal-cli SSE event JSON shapes ──────────────────────────── + +#[derive(Debug, Deserialize)] +struct SseEnvelope { + #[serde(default)] + envelope: Option, +} + +#[derive(Debug, Deserialize)] +struct Envelope { + #[serde(default)] + source: Option, + #[serde(rename = "sourceNumber", default)] + source_number: Option, + #[serde(rename = "dataMessage", default)] + data_message: Option, + #[serde(rename = "storyMessage", default)] + story_message: Option, + #[serde(default)] + timestamp: Option, +} + +#[derive(Debug, Deserialize)] +struct DataMessage { + #[serde(default)] + message: Option, + #[serde(default)] + timestamp: Option, + #[serde(rename = "groupInfo", default)] + group_info: Option, + #[serde(default)] + attachments: Option>, +} + +#[derive(Debug, Deserialize)] +struct GroupInfo { + #[serde(rename = "groupId", default)] + group_id: Option, +} + +impl SignalChannel { + pub fn new( + http_url: String, + account: String, + group_id: Option, + allowed_from: Vec, + ignore_attachments: bool, + ignore_stories: bool, + ) -> Self { + let http_url = http_url.trim_end_matches('/').to_string(); + let client = Client::builder() + .connect_timeout(Duration::from_secs(10)) + .build() + .expect("Signal HTTP client should build"); + Self { + http_url, + account, + group_id, + allowed_from, + ignore_attachments, + ignore_stories, + client, + } + } + + /// Effective sender: prefer `sourceNumber` (E.164), fall back to `source`. + fn sender(envelope: &Envelope) -> Option { + envelope + .source_number + .as_deref() + .or(envelope.source.as_deref()) + .map(String::from) + } + + fn is_sender_allowed(&self, sender: &str) -> bool { + if self.allowed_from.iter().any(|u| u == "*") { + return true; + } + self.allowed_from.iter().any(|u| u == sender) + } + + fn is_e164(recipient: &str) -> bool { + let Some(number) = recipient.strip_prefix('+') else { + return false; + }; + (2..=15).contains(&number.len()) && number.chars().all(|c| c.is_ascii_digit()) + } + + fn parse_recipient_target(recipient: &str) -> RecipientTarget { + if let Some(group_id) = recipient.strip_prefix(GROUP_TARGET_PREFIX) { + return RecipientTarget::Group(group_id.to_string()); + } + + if Self::is_e164(recipient) { + RecipientTarget::Direct(recipient.to_string()) + } else { + RecipientTarget::Group(recipient.to_string()) + } + } + + /// Check whether the message targets the configured group. + /// If no `group_id` is configured (None), all DMs and groups are accepted. + /// Use "dm" to filter DMs only. + fn matches_group(&self, data_msg: &DataMessage) -> bool { + let Some(ref expected) = self.group_id else { + return true; + }; + match data_msg + .group_info + .as_ref() + .and_then(|g| g.group_id.as_deref()) + { + Some(gid) => gid == expected.as_str(), + None => expected.eq_ignore_ascii_case("dm"), + } + } + + /// Determine the send target: group id or the sender's number. + fn reply_target(&self, data_msg: &DataMessage, sender: &str) -> String { + if let Some(group_id) = data_msg + .group_info + .as_ref() + .and_then(|g| g.group_id.as_deref()) + { + format!("{GROUP_TARGET_PREFIX}{group_id}") + } else { + sender.to_string() + } + } + + /// Send a JSON-RPC request to signal-cli daemon. + async fn rpc_request( + &self, + method: &str, + params: serde_json::Value, + ) -> anyhow::Result> { + let url = format!("{}/api/v1/rpc", self.http_url); + let id = Uuid::new_v4().to_string(); + + let body = serde_json::json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": id, + }); + + let resp = self + .client + .post(&url) + .timeout(Duration::from_secs(30)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + // 201 = success with no body (e.g. typing indicators) + if resp.status().as_u16() == 201 { + return Ok(None); + } + + let text = resp.text().await?; + if text.is_empty() { + return Ok(None); + } + + let parsed: serde_json::Value = serde_json::from_str(&text)?; + if let Some(err) = parsed.get("error") { + let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); + let msg = err + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown"); + anyhow::bail!("Signal RPC error {code}: {msg}"); + } + + Ok(parsed.get("result").cloned()) + } + + /// Process a single SSE envelope, returning a ChannelMessage if valid. + fn process_envelope(&self, envelope: &Envelope) -> Option { + // Skip story messages when configured + if self.ignore_stories && envelope.story_message.is_some() { + return None; + } + + let data_msg = envelope.data_message.as_ref()?; + + // Skip attachment-only messages when configured + if self.ignore_attachments { + let has_attachments = data_msg.attachments.as_ref().is_some_and(|a| !a.is_empty()); + if has_attachments && data_msg.message.is_none() { + return None; + } + } + + let text = data_msg.message.as_deref().filter(|t| !t.is_empty())?; + let sender = Self::sender(envelope)?; + + if !self.is_sender_allowed(&sender) { + return None; + } + + if !self.matches_group(data_msg) { + return None; + } + + let target = self.reply_target(data_msg, &sender); + + let timestamp = data_msg + .timestamp + .or(envelope.timestamp) + .unwrap_or_else(|| { + u64::try_from( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis(), + ) + .unwrap_or(u64::MAX) + }); + + Some(ChannelMessage { + id: format!("sig_{timestamp}"), + sender: sender.clone(), + reply_target: target, + content: text.to_string(), + channel: "signal".to_string(), + timestamp: timestamp / 1000, // millis → secs + }) + } +} + +#[async_trait] +impl Channel for SignalChannel { + fn name(&self) -> &str { + "signal" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + let params = match Self::parse_recipient_target(&message.recipient) { + RecipientTarget::Direct(number) => serde_json::json!({ + "recipient": [number], + "message": &message.content, + "account": &self.account, + }), + RecipientTarget::Group(group_id) => serde_json::json!({ + "groupId": group_id, + "message": &message.content, + "account": &self.account, + }), + }; + + self.rpc_request("send", params).await?; + Ok(()) + } + + async fn listen(&self, tx: mpsc::Sender) -> anyhow::Result<()> { + let mut url = reqwest::Url::parse(&format!("{}/api/v1/events", self.http_url))?; + url.query_pairs_mut().append_pair("account", &self.account); + + tracing::info!("Signal channel listening via SSE on {}...", self.http_url); + + let mut retry_delay_secs = 2u64; + let max_delay_secs = 60u64; + + loop { + let resp = self + .client + .get(url.clone()) + .header("Accept", "text/event-stream") + .send() + .await; + + let resp = match resp { + Ok(r) if r.status().is_success() => r, + Ok(r) => { + let status = r.status(); + let body = r.text().await.unwrap_or_default(); + tracing::warn!("Signal SSE returned {status}: {body}"); + tokio::time::sleep(tokio::time::Duration::from_secs(retry_delay_secs)).await; + retry_delay_secs = (retry_delay_secs * 2).min(max_delay_secs); + continue; + } + Err(e) => { + tracing::warn!("Signal SSE connect error: {e}, retrying..."); + tokio::time::sleep(tokio::time::Duration::from_secs(retry_delay_secs)).await; + retry_delay_secs = (retry_delay_secs * 2).min(max_delay_secs); + continue; + } + }; + + retry_delay_secs = 2; + + let mut bytes_stream = resp.bytes_stream(); + let mut buffer = String::new(); + let mut current_data = String::new(); + + while let Some(chunk) = bytes_stream.next().await { + let chunk = match chunk { + Ok(c) => c, + Err(e) => { + tracing::debug!("Signal SSE chunk error, reconnecting: {e}"); + break; + } + }; + + let text = match String::from_utf8(chunk.to_vec()) { + Ok(t) => t, + Err(e) => { + tracing::debug!("Signal SSE invalid UTF-8, skipping chunk: {}", e); + continue; + } + }; + + buffer.push_str(&text); + + while let Some(newline_pos) = buffer.find('\n') { + let line = buffer[..newline_pos].trim_end_matches('\r').to_string(); + buffer = buffer[newline_pos + 1..].to_string(); + + // Skip SSE comments (keepalive) + if line.starts_with(':') { + continue; + } + + if line.is_empty() { + // Empty line = event boundary, dispatch accumulated data + if !current_data.is_empty() { + match serde_json::from_str::(¤t_data) { + Ok(sse) => { + if let Some(ref envelope) = sse.envelope { + if let Some(msg) = self.process_envelope(envelope) { + if tx.send(msg).await.is_err() { + return Ok(()); + } + } + } + } + Err(e) => { + tracing::debug!("Signal SSE parse skip: {e}"); + } + } + current_data.clear(); + } + } else if let Some(data) = line.strip_prefix("data:") { + if !current_data.is_empty() { + current_data.push('\n'); + } + current_data.push_str(data.trim_start()); + } + // Ignore "event:", "id:", "retry:" lines + } + } + + if !current_data.is_empty() { + match serde_json::from_str::(¤t_data) { + Ok(sse) => { + if let Some(ref envelope) = sse.envelope { + if let Some(msg) = self.process_envelope(envelope) { + let _ = tx.send(msg).await; + } + } + } + Err(e) => { + tracing::debug!("Signal SSE trailing parse skip: {e}"); + } + } + } + + tracing::debug!("Signal SSE stream ended, reconnecting..."); + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + } + } + + async fn health_check(&self) -> bool { + let url = format!("{}/api/v1/check", self.http_url); + let Ok(resp) = self + .client + .get(&url) + .timeout(Duration::from_secs(10)) + .send() + .await + else { + return false; + }; + resp.status().is_success() + } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + let params = match Self::parse_recipient_target(recipient) { + RecipientTarget::Direct(number) => serde_json::json!({ + "recipient": [number], + "account": &self.account, + }), + RecipientTarget::Group(group_id) => serde_json::json!({ + "groupId": group_id, + "account": &self.account, + }), + }; + self.rpc_request("sendTyping", params).await?; + Ok(()) + } + + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + // signal-cli doesn't have a stop-typing RPC; typing indicators + // auto-expire after ~15s on the client side. + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_channel() -> SignalChannel { + SignalChannel::new( + "http://127.0.0.1:8686".to_string(), + "+1234567890".to_string(), + None, + vec!["+1111111111".to_string()], + false, + false, + ) + } + + fn make_channel_with_group(group_id: &str) -> SignalChannel { + SignalChannel::new( + "http://127.0.0.1:8686".to_string(), + "+1234567890".to_string(), + Some(group_id.to_string()), + vec!["*".to_string()], + true, + true, + ) + } + + fn make_envelope(source_number: Option<&str>, message: Option<&str>) -> Envelope { + Envelope { + source: source_number.map(String::from), + source_number: source_number.map(String::from), + data_message: message.map(|m| DataMessage { + message: Some(m.to_string()), + timestamp: Some(1_700_000_000_000), + group_info: None, + attachments: None, + }), + story_message: None, + timestamp: Some(1_700_000_000_000), + } + } + + #[test] + fn creates_with_correct_fields() { + let ch = make_channel(); + assert_eq!(ch.http_url, "http://127.0.0.1:8686"); + assert_eq!(ch.account, "+1234567890"); + assert!(ch.group_id.is_none()); + assert_eq!(ch.allowed_from.len(), 1); + assert!(!ch.ignore_attachments); + assert!(!ch.ignore_stories); + } + + #[test] + fn strips_trailing_slash() { + let ch = SignalChannel::new( + "http://127.0.0.1:8686/".to_string(), + "+1234567890".to_string(), + None, + vec![], + false, + false, + ); + assert_eq!(ch.http_url, "http://127.0.0.1:8686"); + } + + #[test] + fn wildcard_allows_anyone() { + let ch = make_channel_with_group("dm"); + assert!(ch.is_sender_allowed("+9999999999")); + } + + #[test] + fn specific_sender_allowed() { + let ch = make_channel(); + assert!(ch.is_sender_allowed("+1111111111")); + } + + #[test] + fn unknown_sender_denied() { + let ch = make_channel(); + assert!(!ch.is_sender_allowed("+9999999999")); + } + + #[test] + fn empty_allowlist_denies_all() { + let ch = SignalChannel::new( + "http://127.0.0.1:8686".to_string(), + "+1234567890".to_string(), + None, + vec![], + false, + false, + ); + assert!(!ch.is_sender_allowed("+1111111111")); + } + + #[test] + fn name_returns_signal() { + let ch = make_channel(); + assert_eq!(ch.name(), "signal"); + } + + #[test] + fn matches_group_no_group_id_accepts_all() { + let ch = make_channel(); + let dm = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: None, + attachments: None, + }; + assert!(ch.matches_group(&dm)); + + let group = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: Some(GroupInfo { + group_id: Some("group123".to_string()), + }), + attachments: None, + }; + assert!(ch.matches_group(&group)); + } + + #[test] + fn matches_group_filters_group() { + let ch = make_channel_with_group("group123"); + let matching = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: Some(GroupInfo { + group_id: Some("group123".to_string()), + }), + attachments: None, + }; + assert!(ch.matches_group(&matching)); + + let non_matching = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: Some(GroupInfo { + group_id: Some("other_group".to_string()), + }), + attachments: None, + }; + assert!(!ch.matches_group(&non_matching)); + } + + #[test] + fn matches_group_dm_keyword() { + let ch = make_channel_with_group("dm"); + let dm = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: None, + attachments: None, + }; + assert!(ch.matches_group(&dm)); + + let group = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: Some(GroupInfo { + group_id: Some("group123".to_string()), + }), + attachments: None, + }; + assert!(!ch.matches_group(&group)); + } + + #[test] + fn reply_target_dm() { + let ch = make_channel(); + let dm = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: None, + attachments: None, + }; + assert_eq!(ch.reply_target(&dm, "+1111111111"), "+1111111111"); + } + + #[test] + fn reply_target_group() { + let ch = make_channel(); + let group = DataMessage { + message: Some("hi".to_string()), + timestamp: Some(1000), + group_info: Some(GroupInfo { + group_id: Some("group123".to_string()), + }), + attachments: None, + }; + assert_eq!(ch.reply_target(&group, "+1111111111"), "group:group123"); + } + + #[test] + fn parse_recipient_target_e164_is_direct() { + assert_eq!( + SignalChannel::parse_recipient_target("+1234567890"), + RecipientTarget::Direct("+1234567890".to_string()) + ); + } + + #[test] + fn parse_recipient_target_prefixed_group_is_group() { + assert_eq!( + SignalChannel::parse_recipient_target("group:abc123"), + RecipientTarget::Group("abc123".to_string()) + ); + } + + #[test] + fn parse_recipient_target_non_e164_plus_is_group() { + assert_eq!( + SignalChannel::parse_recipient_target("+abc123"), + RecipientTarget::Group("+abc123".to_string()) + ); + } + + #[test] + fn sender_prefers_source_number() { + let env = Envelope { + source: Some("uuid-123".to_string()), + source_number: Some("+1111111111".to_string()), + data_message: None, + story_message: None, + timestamp: Some(1000), + }; + assert_eq!(SignalChannel::sender(&env), Some("+1111111111".to_string())); + } + + #[test] + fn sender_falls_back_to_source() { + let env = Envelope { + source: Some("uuid-123".to_string()), + source_number: None, + data_message: None, + story_message: None, + timestamp: Some(1000), + }; + assert_eq!(SignalChannel::sender(&env), Some("uuid-123".to_string())); + } + + #[test] + fn sender_none_when_both_missing() { + let env = Envelope { + source: None, + source_number: None, + data_message: None, + story_message: None, + timestamp: None, + }; + assert_eq!(SignalChannel::sender(&env), None); + } + + #[test] + fn process_envelope_valid_dm() { + let ch = make_channel(); + let env = make_envelope(Some("+1111111111"), Some("Hello!")); + let msg = ch.process_envelope(&env).unwrap(); + assert_eq!(msg.content, "Hello!"); + assert_eq!(msg.sender, "+1111111111"); + assert_eq!(msg.channel, "signal"); + } + + #[test] + fn process_envelope_denied_sender() { + let ch = make_channel(); + let env = make_envelope(Some("+9999999999"), Some("Hello!")); + assert!(ch.process_envelope(&env).is_none()); + } + + #[test] + fn process_envelope_empty_message() { + let ch = make_channel(); + let env = make_envelope(Some("+1111111111"), Some("")); + assert!(ch.process_envelope(&env).is_none()); + } + + #[test] + fn process_envelope_no_data_message() { + let ch = make_channel(); + let env = make_envelope(Some("+1111111111"), None); + assert!(ch.process_envelope(&env).is_none()); + } + + #[test] + fn process_envelope_skips_stories() { + let ch = make_channel_with_group("dm"); + let mut env = make_envelope(Some("+1111111111"), Some("story text")); + env.story_message = Some(serde_json::json!({})); + assert!(ch.process_envelope(&env).is_none()); + } + + #[test] + fn process_envelope_skips_attachment_only() { + let ch = make_channel_with_group("dm"); + let env = Envelope { + source: Some("+1111111111".to_string()), + source_number: Some("+1111111111".to_string()), + data_message: Some(DataMessage { + message: None, + timestamp: Some(1_700_000_000_000), + group_info: None, + attachments: Some(vec![serde_json::json!({"contentType": "image/png"})]), + }), + story_message: None, + timestamp: Some(1_700_000_000_000), + }; + assert!(ch.process_envelope(&env).is_none()); + } + + #[test] + fn sse_envelope_deserializes() { + let json = r#"{ + "envelope": { + "source": "+1111111111", + "sourceNumber": "+1111111111", + "timestamp": 1700000000000, + "dataMessage": { + "message": "Hello Signal!", + "timestamp": 1700000000000 + } + } + }"#; + let sse: SseEnvelope = serde_json::from_str(json).unwrap(); + let env = sse.envelope.unwrap(); + assert_eq!(env.source_number.as_deref(), Some("+1111111111")); + let dm = env.data_message.unwrap(); + assert_eq!(dm.message.as_deref(), Some("Hello Signal!")); + } + + #[test] + fn sse_envelope_deserializes_group() { + let json = r#"{ + "envelope": { + "sourceNumber": "+2222222222", + "dataMessage": { + "message": "Group msg", + "groupInfo": { + "groupId": "abc123" + } + } + } + }"#; + let sse: SseEnvelope = serde_json::from_str(json).unwrap(); + let env = sse.envelope.unwrap(); + let dm = env.data_message.unwrap(); + assert_eq!( + dm.group_info.as_ref().unwrap().group_id.as_deref(), + Some("abc123") + ); + } + + #[test] + fn envelope_defaults() { + let json = r#"{}"#; + let env: Envelope = serde_json::from_str(json).unwrap(); + assert!(env.source.is_none()); + assert!(env.source_number.is_none()); + assert!(env.data_message.is_none()); + assert!(env.story_message.is_none()); + assert!(env.timestamp.is_none()); + } +} diff --git a/src/channels/slack.rs b/src/channels/slack.rs index d8b35cb..9faad48 100644 --- a/src/channels/slack.rs +++ b/src/channels/slack.rs @@ -1,6 +1,5 @@ -use super::traits::{Channel, ChannelMessage}; +use super::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; -use uuid::Uuid; /// Slack channel — polls conversations.history via Web API pub struct SlackChannel { @@ -52,19 +51,40 @@ impl Channel for SlackChannel { "slack" } - async fn send(&self, message: &str, channel: &str) -> anyhow::Result<()> { + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { let body = serde_json::json!({ - "channel": channel, - "text": message + "channel": message.recipient, + "text": message.content }); - self.client + let resp = self + .client .post("https://slack.com/api/chat.postMessage") .bearer_auth(&self.bot_token) .json(&body) .send() .await?; + let status = resp.status(); + let body = resp + .text() + .await + .unwrap_or_else(|e| format!("")); + + if !status.is_success() { + anyhow::bail!("Slack chat.postMessage failed ({status}): {body}"); + } + + // Slack returns 200 for most app-level errors; check JSON "ok" field + let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default(); + if parsed.get("ok") == Some(&serde_json::Value::Bool(false)) { + let err = parsed + .get("error") + .and_then(|e| e.as_str()) + .unwrap_or("unknown"); + anyhow::bail!("Slack chat.postMessage failed: {err}"); + } + Ok(()) } @@ -139,8 +159,9 @@ impl Channel for SlackChannel { last_ts = ts.to_string(); let channel_msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: channel_id.clone(), + id: format!("slack_{channel_id}_{ts}"), + sender: user.to_string(), + reply_target: channel_id.clone(), content: text.to_string(), channel: "slack".to_string(), timestamp: std::time::SystemTime::now() @@ -231,4 +252,53 @@ mod tests { assert!(ch.is_user_allowed("U111")); assert!(ch.is_user_allowed("anyone")); } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn slack_message_id_format_includes_channel_and_ts() { + // Verify that message IDs follow the format: slack_{channel_id}_{ts} + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let expected_id = format!("slack_{channel_id}_{ts}"); + assert_eq!(expected_id, "slack_C12345_1234567890.123456"); + } + + #[test] + fn slack_message_id_is_deterministic() { + // Same channel_id + same ts = same ID (prevents duplicates after restart) + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let id1 = format!("slack_{channel_id}_{ts}"); + let id2 = format!("slack_{channel_id}_{ts}"); + assert_eq!(id1, id2); + } + + #[test] + fn slack_message_id_different_ts_different_id() { + // Different timestamps produce different IDs + let channel_id = "C12345"; + let id1 = format!("slack_{channel_id}_1234567890.123456"); + let id2 = format!("slack_{channel_id}_1234567890.123457"); + assert_ne!(id1, id2); + } + + #[test] + fn slack_message_id_different_channel_different_id() { + // Different channels produce different IDs even with same ts + let ts = "1234567890.123456"; + let id1 = format!("slack_C12345_{ts}"); + let id2 = format!("slack_C67890_{ts}"); + assert_ne!(id1, id2); + } + + #[test] + fn slack_message_id_no_uuid_randomness() { + // Verify format doesn't contain random UUID components + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let id = format!("slack_{channel_id}_{ts}"); + assert!(!id.contains('-')); // No UUID dashes + assert!(id.starts_with("slack_")); + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 1f9b202..a5c8dc5 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -1,31 +1,360 @@ -use super::traits::{Channel, ChannelMessage}; +use super::traits::{Channel, ChannelMessage, SendMessage}; +use crate::config::Config; +use crate::security::pairing::PairingGuard; +use anyhow::Context; use async_trait::async_trait; +use directories::UserDirs; use reqwest::multipart::{Form, Part}; +use std::fs; use std::path::Path; -use uuid::Uuid; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +/// Telegram's maximum message length for text messages +const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096; +const TELEGRAM_BIND_COMMAND: &str = "/bind"; + +/// Split a message into chunks that respect Telegram's 4096 character limit. +/// Tries to split at word boundaries when possible, and handles continuation. +fn split_message_for_telegram(message: &str) -> Vec { + if message.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { + return vec![message.to_string()]; + } + + let mut chunks = Vec::new(); + let mut remaining = message; + + while !remaining.is_empty() { + let chunk_end = if remaining.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { + remaining.len() + } else { + // Try to find a good break point (newline, then space) + let search_area = &remaining[..TELEGRAM_MAX_MESSAGE_LENGTH]; + + // Prefer splitting at newline + if let Some(pos) = search_area.rfind('\n') { + // Don't split if the newline is too close to the start + if pos >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 { + pos + 1 + } else { + // Try space as fallback + search_area + .rfind(' ') + .unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH) + + 1 + } + } else if let Some(pos) = search_area.rfind(' ') { + pos + 1 + } else { + // Hard split at the limit + TELEGRAM_MAX_MESSAGE_LENGTH + } + }; + + chunks.push(remaining[..chunk_end].to_string()); + remaining = &remaining[chunk_end..]; + } + + chunks +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TelegramAttachmentKind { + Image, + Document, + Video, + Audio, + Voice, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct TelegramAttachment { + kind: TelegramAttachmentKind, + target: String, +} + +impl TelegramAttachmentKind { + fn from_marker(marker: &str) -> Option { + match marker.trim().to_ascii_uppercase().as_str() { + "IMAGE" | "PHOTO" => Some(Self::Image), + "DOCUMENT" | "FILE" => Some(Self::Document), + "VIDEO" => Some(Self::Video), + "AUDIO" => Some(Self::Audio), + "VOICE" => Some(Self::Voice), + _ => None, + } + } +} + +fn is_http_url(target: &str) -> bool { + target.starts_with("http://") || target.starts_with("https://") +} + +fn infer_attachment_kind_from_target(target: &str) -> Option { + let normalized = target + .split('?') + .next() + .unwrap_or(target) + .split('#') + .next() + .unwrap_or(target); + + let extension = Path::new(normalized) + .extension() + .and_then(|ext| ext.to_str())? + .to_ascii_lowercase(); + + match extension.as_str() { + "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image), + "mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video), + "mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio), + "ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice), + "pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls" + | "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document), + _ => None, + } +} + +fn parse_path_only_attachment(message: &str) -> Option { + let trimmed = message.trim(); + if trimmed.is_empty() || trimmed.contains('\n') { + return None; + } + + let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\'')); + if candidate.chars().any(char::is_whitespace) { + return None; + } + + let candidate = candidate.strip_prefix("file://").unwrap_or(candidate); + let kind = infer_attachment_kind_from_target(candidate)?; + + if !is_http_url(candidate) && !Path::new(candidate).exists() { + return None; + } + + Some(TelegramAttachment { + kind, + target: candidate.to_string(), + }) +} + +/// Strip tool_call XML-style tags from message text. +/// These tags are used internally but must not be sent to Telegram as raw markup, +/// since Telegram's Markdown parser will reject them (causing status 400 errors). +fn strip_tool_call_tags(message: &str) -> String { + let mut result = message.to_string(); + + // Strip ... + while let Some(start) = result.find("") { + if let Some(end) = result[start..].find("") { + let end = start + end + "".len(); + result = format!("{}{}", &result[..start], &result[end..]); + } else { + break; + } + } + + // Strip ... + while let Some(start) = result.find("") { + if let Some(end) = result[start..].find("") { + let end = start + end + "".len(); + result = format!("{}{}", &result[..start], &result[end..]); + } else { + break; + } + } + + // Strip ... + while let Some(start) = result.find("") { + if let Some(end) = result[start..].find("") { + let end = start + end + "".len(); + result = format!("{}{}", &result[..start], &result[end..]); + } else { + break; + } + } + + // Clean up any resulting blank lines (but preserve paragraphs) + while result.contains("\n\n\n") { + result = result.replace("\n\n\n", "\n\n"); + } + + result.trim().to_string() +} + +fn parse_attachment_markers(message: &str) -> (String, Vec) { + let mut cleaned = String::with_capacity(message.len()); + let mut attachments = Vec::new(); + let mut cursor = 0; + + while cursor < message.len() { + let Some(open_rel) = message[cursor..].find('[') else { + cleaned.push_str(&message[cursor..]); + break; + }; + + let open = cursor + open_rel; + cleaned.push_str(&message[cursor..open]); + + let Some(close_rel) = message[open..].find(']') else { + cleaned.push_str(&message[open..]); + break; + }; + + let close = open + close_rel; + let marker = &message[open + 1..close]; + + let parsed = marker.split_once(':').and_then(|(kind, target)| { + let kind = TelegramAttachmentKind::from_marker(kind)?; + let target = target.trim(); + if target.is_empty() { + return None; + } + Some(TelegramAttachment { + kind, + target: target.to_string(), + }) + }); + + if let Some(attachment) = parsed { + attachments.push(attachment); + } else { + cleaned.push_str(&message[open..=close]); + } + + cursor = close + 1; + } + + (cleaned.trim().to_string(), attachments) +} /// Telegram channel — long-polls the Bot API for updates pub struct TelegramChannel { bot_token: String, - allowed_users: Vec, + allowed_users: Arc>>, + pairing: Option, client: reqwest::Client, } impl TelegramChannel { pub fn new(bot_token: String, allowed_users: Vec) -> Self { + let normalized_allowed = Self::normalize_allowed_users(allowed_users); + let pairing = if normalized_allowed.is_empty() { + let guard = PairingGuard::new(true, &[]); + if let Some(code) = guard.pairing_code() { + println!(" 🔐 Telegram pairing required. One-time bind code: {code}"); + println!(" Send `{TELEGRAM_BIND_COMMAND} ` from your Telegram account."); + } + Some(guard) + } else { + None + }; + Self { bot_token, - allowed_users, + allowed_users: Arc::new(RwLock::new(normalized_allowed)), + pairing, client: reqwest::Client::new(), } } + fn normalize_identity(value: &str) -> String { + value.trim().trim_start_matches('@').to_string() + } + + fn normalize_allowed_users(allowed_users: Vec) -> Vec { + allowed_users + .into_iter() + .map(|entry| Self::normalize_identity(&entry)) + .filter(|entry| !entry.is_empty()) + .collect() + } + + fn load_config_without_env() -> anyhow::Result { + let home = UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + let zeroclaw_dir = home.join(".zeroclaw"); + let config_path = zeroclaw_dir.join("config.toml"); + + let contents = fs::read_to_string(&config_path) + .with_context(|| format!("Failed to read config file: {}", config_path.display()))?; + let mut config: Config = toml::from_str(&contents) + .context("Failed to parse config file for Telegram binding")?; + config.config_path = config_path; + config.workspace_dir = zeroclaw_dir.join("workspace"); + Ok(config) + } + + fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> { + let mut config = Self::load_config_without_env()?; + let Some(telegram) = config.channels_config.telegram.as_mut() else { + anyhow::bail!("Telegram channel config is missing in config.toml"); + }; + + let normalized = Self::normalize_identity(identity); + if normalized.is_empty() { + anyhow::bail!("Cannot persist empty Telegram identity"); + } + + if !telegram.allowed_users.iter().any(|u| u == &normalized) { + telegram.allowed_users.push(normalized); + config + .save() + .context("Failed to persist Telegram allowlist to config.toml")?; + } + + Ok(()) + } + + async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> { + let identity = identity.to_string(); + tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity)) + .await + .map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??; + Ok(()) + } + + fn add_allowed_identity_runtime(&self, identity: &str) { + let normalized = Self::normalize_identity(identity); + if normalized.is_empty() { + return; + } + if let Ok(mut users) = self.allowed_users.write() { + if !users.iter().any(|u| u == &normalized) { + users.push(normalized); + } + } + } + + fn extract_bind_code(text: &str) -> Option<&str> { + let mut parts = text.split_whitespace(); + let command = parts.next()?; + let base_command = command.split('@').next().unwrap_or(command); + if base_command != TELEGRAM_BIND_COMMAND { + return None; + } + parts.next().map(str::trim).filter(|code| !code.is_empty()) + } + + fn pairing_code_active(&self) -> bool { + self.pairing + .as_ref() + .and_then(PairingGuard::pairing_code) + .is_some() + } + fn api_url(&self, method: &str) -> String { format!("https://api.telegram.org/bot{}/{method}", self.bot_token) } fn is_user_allowed(&self, username: &str) -> bool { - self.allowed_users.iter().any(|u| u == "*" || u == username) + let identity = Self::normalize_identity(username); + self.allowed_users + .read() + .map(|users| users.iter().any(|u| u == "*" || u == &identity)) + .unwrap_or(false) } fn is_any_user_allowed<'a, I>(&self, identities: I) -> bool @@ -35,6 +364,365 @@ impl TelegramChannel { identities.into_iter().any(|id| self.is_user_allowed(id)) } + async fn handle_unauthorized_message(&self, update: &serde_json::Value) { + let Some(message) = update.get("message") else { + return; + }; + + let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else { + return; + }; + + let username_opt = message + .get("from") + .and_then(|from| from.get("username")) + .and_then(serde_json::Value::as_str); + let username = username_opt.unwrap_or("unknown"); + let normalized_username = Self::normalize_identity(username); + + let user_id = message + .get("from") + .and_then(|from| from.get("id")) + .and_then(serde_json::Value::as_i64); + let user_id_str = user_id.map(|id| id.to_string()); + let normalized_user_id = user_id_str.as_deref().map(Self::normalize_identity); + + let chat_id = message + .get("chat") + .and_then(|chat| chat.get("id")) + .and_then(serde_json::Value::as_i64) + .map(|id| id.to_string()); + + let Some(chat_id) = chat_id else { + tracing::warn!("Telegram: missing chat_id in message, skipping"); + return; + }; + + let mut identities = vec![normalized_username.as_str()]; + if let Some(ref id) = normalized_user_id { + identities.push(id.as_str()); + } + + if self.is_any_user_allowed(identities.iter().copied()) { + return; + } + + if let Some(code) = Self::extract_bind_code(text) { + if let Some(pairing) = self.pairing.as_ref() { + match pairing.try_pair(code) { + Ok(Some(_token)) => { + let bind_identity = normalized_user_id.clone().or_else(|| { + if normalized_username.is_empty() || normalized_username == "unknown" { + None + } else { + Some(normalized_username.clone()) + } + }); + + if let Some(identity) = bind_identity { + self.add_allowed_identity_runtime(&identity); + match self.persist_allowed_identity(&identity).await { + Ok(()) => { + let _ = self + .send(&SendMessage::new( + "✅ Telegram account bound successfully. You can talk to ZeroClaw now.", + &chat_id, + )) + .await; + tracing::info!( + "Telegram: paired and allowlisted identity={identity}" + ); + } + Err(e) => { + tracing::error!( + "Telegram: failed to persist allowlist after bind: {e}" + ); + let _ = self + .send(&SendMessage::new( + "⚠️ Bound for this runtime, but failed to persist config. Access may be lost after restart; check config file permissions.", + &chat_id, + )) + .await; + } + } + } else { + let _ = self + .send(&SendMessage::new( + "❌ Could not identify your Telegram account. Ensure your account has a username or stable user ID, then retry.", + &chat_id, + )) + .await; + } + } + Ok(None) => { + let _ = self + .send(&SendMessage::new( + "❌ Invalid binding code. Ask operator for the latest code and retry.", + &chat_id, + )) + .await; + } + Err(lockout_secs) => { + let _ = self + .send(&SendMessage::new( + format!("⏳ Too many invalid attempts. Retry in {lockout_secs}s."), + &chat_id, + )) + .await; + } + } + } else { + let _ = self + .send(&SendMessage::new( + "ℹ️ Telegram pairing is not active. Ask operator to update allowlist in config.toml.", + &chat_id, + )) + .await; + } + return; + } + + tracing::warn!( + "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ +Allowlist Telegram username (without '@') or numeric user ID.", + user_id_str.as_deref().unwrap_or("unknown") + ); + + let suggested_identity = normalized_user_id + .clone() + .or_else(|| { + if normalized_username.is_empty() || normalized_username == "unknown" { + None + } else { + Some(normalized_username.clone()) + } + }) + .unwrap_or_else(|| "YOUR_TELEGRAM_ID".to_string()); + + let _ = self + .send(&SendMessage::new( + format!( + "🔐 This bot requires operator approval.\n\nCopy this command to operator terminal:\n`zeroclaw channel bind-telegram {suggested_identity}`\n\nAfter operator runs it, send your message again." + ), + &chat_id, + )) + .await; + + if self.pairing_code_active() { + let _ = self + .send(&SendMessage::new( + "ℹ️ If operator provides a one-time pairing code, you can also run `/bind `.", + &chat_id, + )) + .await; + } + } + + fn parse_update_message(&self, update: &serde_json::Value) -> Option { + let message = update.get("message")?; + + let text = message.get("text").and_then(serde_json::Value::as_str)?; + + let username = message + .get("from") + .and_then(|from| from.get("username")) + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown") + .to_string(); + + let user_id = message + .get("from") + .and_then(|from| from.get("id")) + .and_then(serde_json::Value::as_i64) + .map(|id| id.to_string()); + + let sender_identity = if username == "unknown" { + user_id.clone().unwrap_or_else(|| "unknown".to_string()) + } else { + username.clone() + }; + + let mut identities = vec![username.as_str()]; + if let Some(id) = user_id.as_deref() { + identities.push(id); + } + + if !self.is_any_user_allowed(identities.iter().copied()) { + return None; + } + + let chat_id = message + .get("chat") + .and_then(|chat| chat.get("id")) + .and_then(serde_json::Value::as_i64) + .map(|id| id.to_string())?; + + let message_id = message + .get("message_id") + .and_then(serde_json::Value::as_i64) + .unwrap_or(0); + + Some(ChannelMessage { + id: format!("telegram_{chat_id}_{message_id}"), + sender: sender_identity, + reply_target: chat_id, + content: text.to_string(), + channel: "telegram".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }) + } + + async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> { + let chunks = split_message_for_telegram(message); + + for (index, chunk) in chunks.iter().enumerate() { + let text = if chunks.len() > 1 { + if index == 0 { + format!("{chunk}\n\n(continues...)") + } else if index == chunks.len() - 1 { + format!("(continued)\n\n{chunk}") + } else { + format!("(continued)\n\n{chunk}\n\n(continues...)") + } + } else { + chunk.to_string() + }; + + let markdown_body = serde_json::json!({ + "chat_id": chat_id, + "text": text, + "parse_mode": "Markdown" + }); + + let markdown_resp = self + .client + .post(self.api_url("sendMessage")) + .json(&markdown_body) + .send() + .await?; + + if markdown_resp.status().is_success() { + if index < chunks.len() - 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + continue; + } + + let markdown_status = markdown_resp.status(); + let markdown_err = markdown_resp.text().await.unwrap_or_default(); + tracing::warn!( + status = ?markdown_status, + "Telegram sendMessage with Markdown failed; retrying without parse_mode" + ); + + let plain_body = serde_json::json!({ + "chat_id": chat_id, + "text": text, + }); + let plain_resp = self + .client + .post(self.api_url("sendMessage")) + .json(&plain_body) + .send() + .await?; + + if !plain_resp.status().is_success() { + let plain_status = plain_resp.status(); + let plain_err = plain_resp.text().await.unwrap_or_default(); + anyhow::bail!( + "Telegram sendMessage failed (markdown {}: {}; plain {}: {})", + markdown_status, + markdown_err, + plain_status, + plain_err + ); + } + + if index < chunks.len() - 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + Ok(()) + } + + async fn send_media_by_url( + &self, + method: &str, + media_field: &str, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + let mut body = serde_json::json!({ + "chat_id": chat_id, + }); + body[media_field] = serde_json::Value::String(url.to_string()); + + if let Some(cap) = caption { + body["caption"] = serde_json::Value::String(cap.to_string()); + } + + let resp = self + .client + .post(self.api_url(method)) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let err = resp.text().await?; + anyhow::bail!("Telegram {method} by URL failed: {err}"); + } + + tracing::info!("Telegram {method} sent to {chat_id}: {url}"); + Ok(()) + } + + async fn send_attachment( + &self, + chat_id: &str, + attachment: &TelegramAttachment, + ) -> anyhow::Result<()> { + let target = attachment.target.trim(); + + if is_http_url(target) { + return match attachment.kind { + TelegramAttachmentKind::Image => { + self.send_photo_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Document => { + self.send_document_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Video => { + self.send_video_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Audio => { + self.send_audio_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Voice => { + self.send_voice_by_url(chat_id, target, None).await + } + }; + } + + let path = Path::new(target); + if !path.exists() { + anyhow::bail!("Telegram attachment path not found: {target}"); + } + + match attachment.kind { + TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await, + TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await, + TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await, + TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await, + TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await, + } + } + /// Send a document/file to a Telegram chat pub async fn send_document( &self, @@ -361,6 +1049,39 @@ impl TelegramChannel { tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}"); Ok(()) } + + /// Send a video by URL (Telegram will download it) + pub async fn send_video_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendVideo", "video", chat_id, url, caption) + .await + } + + /// Send an audio file by URL (Telegram will download it) + pub async fn send_audio_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendAudio", "audio", chat_id, url, caption) + .await + } + + /// Send a voice message by URL (Telegram will download it) + pub async fn send_voice_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendVoice", "voice", chat_id, url, caption) + .await + } } #[async_trait] @@ -369,20 +1090,32 @@ impl Channel for TelegramChannel { "telegram" } - async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> { - let body = serde_json::json!({ - "chat_id": chat_id, - "text": message, - "parse_mode": "Markdown" - }); + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + // Strip tool_call tags before processing to prevent Markdown parsing failures + let content = strip_tool_call_tags(&message.content); - self.client - .post(self.api_url("sendMessage")) - .json(&body) - .send() - .await?; + let (text_without_markers, attachments) = parse_attachment_markers(&content); - Ok(()) + if !attachments.is_empty() { + if !text_without_markers.is_empty() { + self.send_text_chunks(&text_without_markers, &message.recipient) + .await?; + } + + for attachment in &attachments { + self.send_attachment(&message.recipient, attachment).await?; + } + + return Ok(()); + } + + if let Some(attachment) = parse_path_only_attachment(&content) { + self.send_attachment(&message.recipient, &attachment) + .await?; + return Ok(()); + } + + self.send_text_chunks(&content, &message.recipient).await } async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { @@ -416,6 +1149,36 @@ impl Channel for TelegramChannel { } }; + let ok = data + .get("ok") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true); + if !ok { + let error_code = data + .get("error_code") + .and_then(serde_json::Value::as_i64) + .unwrap_or_default(); + let description = data + .get("description") + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown Telegram API error"); + + if error_code == 409 { + tracing::warn!( + "Telegram polling conflict (409): {description}. \ +Ensure only one `zeroclaw` process is using this bot token." + ); + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + } else { + tracing::warn!( + "Telegram getUpdates API error (code={}): {description}", + error_code + ); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + continue; + } + if let Some(results) = data.get("result").and_then(serde_json::Value::as_array) { for update in results { // Advance offset past this update @@ -423,57 +1186,21 @@ impl Channel for TelegramChannel { offset = uid + 1; } - let Some(message) = update.get("message") else { + let Some(msg) = self.parse_update_message(update) else { + self.handle_unauthorized_message(update).await; continue; }; - - let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else { - continue; - }; - - let username_opt = message - .get("from") - .and_then(|f| f.get("username")) - .and_then(|u| u.as_str()); - let username = username_opt.unwrap_or("unknown"); - - let user_id = message - .get("from") - .and_then(|f| f.get("id")) - .and_then(serde_json::Value::as_i64); - let user_id_str = user_id.map(|id| id.to_string()); - - let mut identities = vec![username]; - if let Some(ref id) = user_id_str { - identities.push(id.as_str()); - } - - if !self.is_any_user_allowed(identities.iter().copied()) { - tracing::warn!( - "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ -Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.", - user_id_str.as_deref().unwrap_or("unknown") - ); - continue; - } - - let chat_id = message - .get("chat") - .and_then(|c| c.get("id")) - .and_then(serde_json::Value::as_i64) - .map(|id| id.to_string()) - .unwrap_or_default(); - - let msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: chat_id, - content: text.to_string(), - channel: "telegram".to_string(), - timestamp: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - }; + // Send "typing" indicator immediately when we receive a message + let typing_body = serde_json::json!({ + "chat_id": &msg.reply_target, + "action": "typing" + }); + let _ = self + .client + .post(self.api_url("sendChatAction")) + .json(&typing_body) + .send() + .await; // Ignore errors for typing indicator if tx.send(msg).await.is_err() { return Ok(()); @@ -484,12 +1211,24 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch } async fn health_check(&self) -> bool { - self.client - .get(self.api_url("getMe")) - .send() - .await - .map(|r| r.status().is_success()) - .unwrap_or(false) + let timeout_duration = Duration::from_secs(5); + + match tokio::time::timeout( + timeout_duration, + self.client.get(self.api_url("getMe")).send(), + ) + .await + { + Ok(Ok(resp)) => resp.status().is_success(), + Ok(Err(e)) => { + tracing::debug!("Telegram health check failed: {e}"); + false + } + Err(_) => { + tracing::debug!("Telegram health check timed out after 5s"); + false + } + } } } @@ -525,6 +1264,12 @@ mod tests { assert!(!ch.is_user_allowed("eve")); } + #[test] + fn telegram_user_allowed_with_at_prefix_in_config() { + let ch = TelegramChannel::new("t".into(), vec!["@alice".into()]); + assert!(ch.is_user_allowed("alice")); + } + #[test] fn telegram_user_denied_empty() { let ch = TelegramChannel::new("t".into(), vec![]); @@ -573,6 +1318,141 @@ mod tests { assert!(!ch.is_any_user_allowed(["unknown", "123456789"])); } + #[test] + fn telegram_pairing_enabled_with_empty_allowlist() { + let ch = TelegramChannel::new("t".into(), vec![]); + assert!(ch.pairing_code_active()); + } + + #[test] + fn telegram_pairing_disabled_with_nonempty_allowlist() { + let ch = TelegramChannel::new("t".into(), vec!["alice".into()]); + assert!(!ch.pairing_code_active()); + } + + #[test] + fn telegram_extract_bind_code_plain_command() { + assert_eq!( + TelegramChannel::extract_bind_code("/bind 123456"), + Some("123456") + ); + } + + #[test] + fn telegram_extract_bind_code_supports_bot_mention() { + assert_eq!( + TelegramChannel::extract_bind_code("/bind@zeroclaw_bot 654321"), + Some("654321") + ); + } + + #[test] + fn telegram_extract_bind_code_rejects_invalid_forms() { + assert_eq!(TelegramChannel::extract_bind_code("/bind"), None); + assert_eq!(TelegramChannel::extract_bind_code("/start"), None); + } + + #[test] + fn parse_attachment_markers_extracts_multiple_types() { + let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]"; + let (cleaned, attachments) = parse_attachment_markers(message); + + assert_eq!(cleaned, "Here are files and"); + assert_eq!(attachments.len(), 2); + assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image); + assert_eq!(attachments[0].target, "/tmp/a.png"); + assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document); + assert_eq!(attachments[1].target, "https://example.com/a.pdf"); + } + + #[test] + fn parse_attachment_markers_keeps_invalid_markers_in_text() { + let message = "Report [UNKNOWN:/tmp/a.bin]"; + let (cleaned, attachments) = parse_attachment_markers(message); + + assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]"); + assert!(attachments.is_empty()); + } + + #[test] + fn parse_path_only_attachment_detects_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let image_path = dir.path().join("snap.png"); + std::fs::write(&image_path, b"fake-png").unwrap(); + + let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref()) + .expect("expected attachment"); + + assert_eq!(parsed.kind, TelegramAttachmentKind::Image); + assert_eq!(parsed.target, image_path.to_string_lossy()); + } + + #[test] + fn parse_path_only_attachment_rejects_sentence_text() { + assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none()); + } + + #[test] + fn infer_attachment_kind_from_target_detects_document_extension() { + assert_eq!( + infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"), + Some(TelegramAttachmentKind::Document) + ); + } + + #[test] + fn parse_update_message_uses_chat_id_as_reply_target() { + let ch = TelegramChannel::new("token".into(), vec!["*".into()]); + let update = serde_json::json!({ + "update_id": 1, + "message": { + "message_id": 33, + "text": "hello", + "from": { + "id": 555, + "username": "alice" + }, + "chat": { + "id": -100_200_300 + } + } + }); + + let msg = ch + .parse_update_message(&update) + .expect("message should parse"); + + assert_eq!(msg.sender, "alice"); + assert_eq!(msg.reply_target, "-100200300"); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.id, "telegram_-100200300_33"); + } + + #[test] + fn parse_update_message_allows_numeric_id_without_username() { + let ch = TelegramChannel::new("token".into(), vec!["555".into()]); + let update = serde_json::json!({ + "update_id": 2, + "message": { + "message_id": 9, + "text": "ping", + "from": { + "id": 555 + }, + "chat": { + "id": 12345 + } + } + }); + + let msg = ch + .parse_update_message(&update) + .expect("numeric allowlist should pass"); + + assert_eq!(msg.sender, "555"); + assert_eq!(msg.reply_target, "12345"); + } + // ── File sending API URL tests ────────────────────────────────── #[test] @@ -737,6 +1617,82 @@ mod tests { assert!(result.is_err()); } + // ── Message splitting tests ───────────────────────────────────── + + #[test] + fn telegram_split_short_message() { + let msg = "Hello, world!"; + let chunks = split_message_for_telegram(msg); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], msg); + } + + #[test] + fn telegram_split_exact_limit() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH); + let chunks = split_message_for_telegram(&msg); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].len(), TELEGRAM_MAX_MESSAGE_LENGTH); + } + + #[test] + fn telegram_split_over_limit() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH + 100); + let chunks = split_message_for_telegram(&msg); + assert_eq!(chunks.len(), 2); + assert!(chunks[0].len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + assert!(chunks[1].len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + + #[test] + fn telegram_split_at_word_boundary() { + let msg = format!( + "{} more text here", + "word ".repeat(TELEGRAM_MAX_MESSAGE_LENGTH / 5) + ); + let chunks = split_message_for_telegram(&msg); + assert!(chunks.len() >= 2); + // First chunk should end with a complete word (space at the end) + for chunk in &chunks[..chunks.len() - 1] { + assert!(chunk.len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + } + + #[test] + fn telegram_split_at_newline() { + let text_block = "Line of text\n".repeat(TELEGRAM_MAX_MESSAGE_LENGTH / 13 + 1); + let chunks = split_message_for_telegram(&text_block); + assert!(chunks.len() >= 2); + for chunk in chunks { + assert!(chunk.len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + } + + #[test] + fn telegram_split_preserves_content() { + let msg = "test ".repeat(TELEGRAM_MAX_MESSAGE_LENGTH / 5 + 100); + let chunks = split_message_for_telegram(&msg); + let rejoined = chunks.join(""); + assert_eq!(rejoined, msg); + } + + #[test] + fn telegram_split_empty_message() { + let chunks = split_message_for_telegram(""); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], ""); + } + + #[test] + fn telegram_split_very_long_message() { + let msg = "x".repeat(TELEGRAM_MAX_MESSAGE_LENGTH * 3); + let chunks = split_message_for_telegram(&msg); + assert!(chunks.len() >= 3); + for chunk in chunks { + assert!(chunk.len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + } + // ── Caption handling tests ────────────────────────────────────── #[tokio::test] @@ -818,4 +1774,135 @@ mod tests { // Should not panic assert!(result.is_err()); } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn telegram_message_id_format_includes_chat_and_message_id() { + // Verify that message IDs follow the format: telegram_{chat_id}_{message_id} + let chat_id = "123456"; + let message_id = 789; + let expected_id = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(expected_id, "telegram_123456_789"); + } + + #[test] + fn telegram_message_id_is_deterministic() { + // Same chat_id + same message_id = same ID (prevents duplicates after restart) + let chat_id = "123456"; + let message_id = 789; + let id1 = format!("telegram_{chat_id}_{message_id}"); + let id2 = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(id1, id2); + } + + #[test] + fn telegram_message_id_different_message_different_id() { + // Different message IDs produce different IDs + let chat_id = "123456"; + let id1 = format!("telegram_{chat_id}_789"); + let id2 = format!("telegram_{chat_id}_790"); + assert_ne!(id1, id2); + } + + #[test] + fn telegram_message_id_different_chat_different_id() { + // Different chats produce different IDs even with same message_id + let message_id = 789; + let id1 = format!("telegram_123456_{message_id}"); + let id2 = format!("telegram_789012_{message_id}"); + assert_ne!(id1, id2); + } + + #[test] + fn telegram_message_id_no_uuid_randomness() { + // Verify format doesn't contain random UUID components + let chat_id = "123456"; + let message_id = 789; + let id = format!("telegram_{chat_id}_{message_id}"); + assert!(!id.contains('-')); // No UUID dashes + assert!(id.starts_with("telegram_")); + } + + #[test] + fn telegram_message_id_handles_zero_message_id() { + // Edge case: message_id can be 0 (fallback/missing case) + let chat_id = "123456"; + let message_id = 0; + let id = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(id, "telegram_123456_0"); + } + + // ── Tool call tag stripping tests ─────────────────────────────────── + + #[test] + fn strip_tool_call_tags_removes_standard_tags() { + let input = + "Hello {\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}} world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + + #[test] + fn strip_tool_call_tags_removes_alias_tags() { + let input = "Hello {\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}} world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + + #[test] + fn strip_tool_call_tags_removes_dash_tags() { + let input = "Hello {\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}} world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + + #[test] + fn strip_tool_call_tags_handles_multiple_tags() { + let input = "Start a middle b end"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Start middle end"); + } + + #[test] + fn strip_tool_call_tags_handles_mixed_tags() { + let input = "A a B b C c D"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "A B C D"); + } + + #[test] + fn strip_tool_call_tags_preserves_normal_text() { + let input = "Hello world! This is a test."; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world! This is a test."); + } + + #[test] + fn strip_tool_call_tags_handles_unclosed_tags() { + let input = "Hello world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + + #[test] + fn strip_tool_call_tags_cleans_extra_newlines() { + let input = "Hello\n\n\ntest\n\n\n\nworld"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello\n\nworld"); + } + + #[test] + fn strip_tool_call_tags_handles_empty_input() { + let input = ""; + let result = strip_tool_call_tags(input); + assert_eq!(result, ""); + } + + #[test] + fn strip_tool_call_tags_handles_only_tags() { + let input = "{\"name\":\"test\"}"; + let result = strip_tool_call_tags(input); + assert_eq!(result, ""); + } } diff --git a/src/channels/traits.rs b/src/channels/traits.rs index 4709a1b..1731ba8 100644 --- a/src/channels/traits.rs +++ b/src/channels/traits.rs @@ -5,11 +5,44 @@ use async_trait::async_trait; pub struct ChannelMessage { pub id: String, pub sender: String, + pub reply_target: String, pub content: String, pub channel: String, pub timestamp: u64, } +/// Message to send through a channel +#[derive(Debug, Clone)] +pub struct SendMessage { + pub content: String, + pub recipient: String, + pub subject: Option, +} + +impl SendMessage { + /// Create a new message with content and recipient + pub fn new(content: impl Into, recipient: impl Into) -> Self { + Self { + content: content.into(), + recipient: recipient.into(), + subject: None, + } + } + + /// Create a new message with content, recipient, and subject + pub fn with_subject( + content: impl Into, + recipient: impl Into, + subject: impl Into, + ) -> Self { + Self { + content: content.into(), + recipient: recipient.into(), + subject: Some(subject.into()), + } + } +} + /// Core channel trait — implement for any messaging platform #[async_trait] pub trait Channel: Send + Sync { @@ -17,7 +50,7 @@ pub trait Channel: Send + Sync { fn name(&self) -> &str; /// Send a message through this channel - async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()>; + async fn send(&self, message: &SendMessage) -> anyhow::Result<()>; /// Start listening for incoming messages (long-running) async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()>; @@ -26,4 +59,95 @@ pub trait Channel: Send + Sync { async fn health_check(&self) -> bool { true } + + /// Signal that the bot is processing a response (e.g. "typing" indicator). + /// Implementations should repeat the indicator as needed for their platform. + async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } + + /// Stop any active typing indicator. + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct DummyChannel; + + #[async_trait] + impl Channel for DummyChannel { + fn name(&self) -> &str { + "dummy" + } + + async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> { + Ok(()) + } + + async fn listen( + &self, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + tx.send(ChannelMessage { + id: "1".into(), + sender: "tester".into(), + reply_target: "tester".into(), + content: "hello".into(), + channel: "dummy".into(), + timestamp: 123, + }) + .await + .map_err(|e| anyhow::anyhow!(e.to_string())) + } + } + + #[test] + fn channel_message_clone_preserves_fields() { + let message = ChannelMessage { + id: "42".into(), + sender: "alice".into(), + reply_target: "alice".into(), + content: "ping".into(), + channel: "dummy".into(), + timestamp: 999, + }; + + let cloned = message.clone(); + assert_eq!(cloned.id, "42"); + assert_eq!(cloned.sender, "alice"); + assert_eq!(cloned.reply_target, "alice"); + assert_eq!(cloned.content, "ping"); + assert_eq!(cloned.channel, "dummy"); + assert_eq!(cloned.timestamp, 999); + } + + #[tokio::test] + async fn default_trait_methods_return_success() { + let channel = DummyChannel; + + assert!(channel.health_check().await); + assert!(channel.start_typing("bob").await.is_ok()); + assert!(channel.stop_typing("bob").await.is_ok()); + assert!(channel + .send(&SendMessage::new("hello", "bob")) + .await + .is_ok()); + } + + #[tokio::test] + async fn listen_sends_message_to_channel() { + let channel = DummyChannel; + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + channel.listen(tx).await.unwrap(); + + let received = rx.recv().await.expect("message should be sent"); + assert_eq!(received.sender, "tester"); + assert_eq!(received.content, "hello"); + assert_eq!(received.channel, "dummy"); + } } diff --git a/src/channels/whatsapp.rs b/src/channels/whatsapp.rs index 3e4c045..34b8dc5 100644 --- a/src/channels/whatsapp.rs +++ b/src/channels/whatsapp.rs @@ -1,4 +1,4 @@ -use super::traits::{Channel, ChannelMessage}; +use super::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; use uuid::Uuid; @@ -10,7 +10,7 @@ use uuid::Uuid; /// happens in the gateway when Meta sends webhook events. pub struct WhatsAppChannel { access_token: String, - phone_number_id: String, + endpoint_id: String, verify_token: String, allowed_numbers: Vec, client: reqwest::Client, @@ -19,13 +19,13 @@ pub struct WhatsAppChannel { impl WhatsAppChannel { pub fn new( access_token: String, - phone_number_id: String, + endpoint_id: String, verify_token: String, allowed_numbers: Vec, ) -> Self { Self { access_token, - phone_number_id, + endpoint_id, verify_token, allowed_numbers, client: reqwest::Client::new(), @@ -119,6 +119,7 @@ impl WhatsAppChannel { messages.push(ChannelMessage { id: Uuid::new_v4().to_string(), + reply_target: normalized_from.clone(), sender: normalized_from, content, channel: "whatsapp".to_string(), @@ -138,15 +139,18 @@ impl Channel for WhatsAppChannel { "whatsapp" } - async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> { + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { // WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages let url = format!( "https://graph.facebook.com/v18.0/{}/messages", - self.phone_number_id + self.endpoint_id ); // Normalize recipient (remove leading + if present for API) - let to = recipient.strip_prefix('+').unwrap_or(recipient); + let to = message + .recipient + .strip_prefix('+') + .unwrap_or(&message.recipient); let body = serde_json::json!({ "messaging_product": "whatsapp", @@ -155,14 +159,14 @@ impl Channel for WhatsAppChannel { "type": "text", "text": { "preview_url": false, - "body": message + "body": message.content } }); let resp = self .client .post(&url) - .header("Authorization", format!("Bearer {}", self.access_token)) + .bearer_auth(&self.access_token) .header("Content-Type", "application/json") .json(&body) .send() @@ -195,11 +199,11 @@ impl Channel for WhatsAppChannel { async fn health_check(&self) -> bool { // Check if we can reach the WhatsApp API - let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id); + let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id); self.client .get(&url) - .header("Authorization", format!("Bearer {}", self.access_token)) + .bearer_auth(&self.access_token) .send() .await .map(|r| r.status().is_success()) diff --git a/src/config/mod.rs b/src/config/mod.rs index f5849c1..8e37cce 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,8 +1,58 @@ pub mod schema; +#[allow(unused_imports)] pub use schema::{ - AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, - GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig, - ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, - TelegramConfig, TunnelConfig, WebhookConfig, + AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, + ChannelsConfig, ComposioConfig, Config, CostConfig, CronConfig, DelegateAgentConfig, + DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig, HardwareTransport, + HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, + MemoryConfig, ModelRouteConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, + ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, + SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, TunnelConfig, + WebhookConfig, }; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reexported_config_default_is_constructible() { + let config = Config::default(); + + assert!(config.default_provider.is_some()); + assert!(config.default_model.is_some()); + assert!(config.default_temperature > 0.0); + } + + #[test] + fn reexported_channel_configs_are_constructible() { + let telegram = TelegramConfig { + bot_token: "token".into(), + allowed_users: vec!["alice".into()], + }; + + let discord = DiscordConfig { + bot_token: "token".into(), + guild_id: Some("123".into()), + allowed_users: vec![], + listen_to_bots: false, + mention_only: false, + }; + + let lark = LarkConfig { + app_id: "app-id".into(), + app_secret: "app-secret".into(), + encrypt_key: None, + verification_token: None, + allowed_users: vec![], + use_feishu: false, + receive_mode: crate::config::schema::LarkReceiveMode::Websocket, + port: None, + }; + + assert_eq!(telegram.allowed_users.len(), 1); + assert_eq!(discord.guild_id.as_deref(), Some("123")); + assert_eq!(lark.app_id, "app-id"); + } +} diff --git a/src/config/schema.rs b/src/config/schema.rs index e437407..41e556d 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1,17 +1,26 @@ +use crate::providers::{is_glm_alias, is_zai_alias}; use crate::security::AutonomyLevel; use anyhow::{Context, Result}; use directories::UserDirs; use serde::{Deserialize, Serialize}; -use std::fs; -use std::path::PathBuf; +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; // ── Top-level config ────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { + /// Workspace directory - computed from home, not serialized + #[serde(skip)] pub workspace_dir: PathBuf, + /// Path to config.toml - computed from home, not serialized + #[serde(skip)] pub config_path: PathBuf, pub api_key: Option, + /// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama) + pub api_url: Option, pub default_provider: Option, pub default_model: Option, pub default_temperature: f64, @@ -28,9 +37,22 @@ pub struct Config { #[serde(default)] pub reliability: ReliabilityConfig, + #[serde(default)] + pub scheduler: SchedulerConfig, + + #[serde(default)] + pub agent: AgentConfig, + + /// Model routing rules — route `hint:` to specific provider+model combos. + #[serde(default)] + pub model_routes: Vec, + #[serde(default)] pub heartbeat: HeartbeatConfig, + #[serde(default)] + pub cron: CronConfig, + #[serde(default)] pub channels_config: ChannelsConfig, @@ -52,8 +74,161 @@ pub struct Config { #[serde(default)] pub browser: BrowserConfig, + #[serde(default)] + pub http_request: HttpRequestConfig, + #[serde(default)] pub identity: IdentityConfig, + + #[serde(default)] + pub cost: CostConfig, + + #[serde(default)] + pub peripherals: PeripheralsConfig, + + /// Delegate agent configurations for multi-agent workflows. + #[serde(default)] + pub agents: HashMap, + + /// Hardware configuration (wizard-driven physical world setup). + #[serde(default)] + pub hardware: HardwareConfig, +} + +// ── Delegate Agents ────────────────────────────────────────────── + +/// Configuration for a delegate sub-agent used by the `delegate` tool. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DelegateAgentConfig { + /// Provider name (e.g. "ollama", "openrouter", "anthropic") + pub provider: String, + /// Model name + pub model: String, + /// Optional system prompt for the sub-agent + #[serde(default)] + pub system_prompt: Option, + /// Optional API key override + #[serde(default)] + pub api_key: Option, + /// Temperature override + #[serde(default)] + pub temperature: Option, + /// Max recursion depth for nested delegation + #[serde(default = "default_max_depth")] + pub max_depth: u32, +} + +fn default_max_depth() -> u32 { + 3 +} + +// ── Hardware Config (wizard-driven) ───────────────────────────── + +/// Hardware transport mode. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +pub enum HardwareTransport { + #[default] + None, + Native, + Serial, + Probe, +} + +impl std::fmt::Display for HardwareTransport { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Native => write!(f, "native"), + Self::Serial => write!(f, "serial"), + Self::Probe => write!(f, "probe"), + } + } +} + +/// Wizard-driven hardware configuration for physical world interaction. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HardwareConfig { + /// Whether hardware access is enabled + #[serde(default)] + pub enabled: bool, + /// Transport mode + #[serde(default)] + pub transport: HardwareTransport, + /// Serial port path (e.g. "/dev/ttyACM0") + #[serde(default)] + pub serial_port: Option, + /// Serial baud rate + #[serde(default = "default_baud_rate")] + pub baud_rate: u32, + /// Probe target chip (e.g. "STM32F401RE") + #[serde(default)] + pub probe_target: Option, + /// Enable workspace datasheet RAG (index PDF schematics for AI pin lookups) + #[serde(default)] + pub workspace_datasheets: bool, +} + +fn default_baud_rate() -> u32 { + 115_200 +} + +impl HardwareConfig { + /// Return the active transport mode. + pub fn transport_mode(&self) -> HardwareTransport { + self.transport.clone() + } +} + +impl Default for HardwareConfig { + fn default() -> Self { + Self { + enabled: false, + transport: HardwareTransport::None, + serial_port: None, + baud_rate: default_baud_rate(), + probe_target: None, + workspace_datasheets: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentConfig { + /// When true: bootstrap_max_chars=6000, rag_chunk_limit=2. Use for 13B or smaller models. + #[serde(default)] + pub compact_context: bool, + #[serde(default = "default_agent_max_tool_iterations")] + pub max_tool_iterations: usize, + #[serde(default = "default_agent_max_history_messages")] + pub max_history_messages: usize, + #[serde(default)] + pub parallel_tools: bool, + #[serde(default = "default_agent_tool_dispatcher")] + pub tool_dispatcher: String, +} + +fn default_agent_max_tool_iterations() -> usize { + 10 +} + +fn default_agent_max_history_messages() -> usize { + 50 +} + +fn default_agent_tool_dispatcher() -> String { + "auto".into() +} + +impl Default for AgentConfig { + fn default() -> Self { + Self { + compact_context: false, + max_tool_iterations: default_agent_max_tool_iterations(), + max_history_messages: default_agent_max_history_messages(), + parallel_tools: false, + tool_dispatcher: default_agent_tool_dispatcher(), + } + } } // ── Identity (AIEOS / OpenClaw format) ────────────────────────── @@ -85,14 +260,205 @@ impl Default for IdentityConfig { } } +// ── Cost tracking and budget enforcement ─────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostConfig { + /// Enable cost tracking (default: false) + #[serde(default)] + pub enabled: bool, + + /// Daily spending limit in USD (default: 10.00) + #[serde(default = "default_daily_limit")] + pub daily_limit_usd: f64, + + /// Monthly spending limit in USD (default: 100.00) + #[serde(default = "default_monthly_limit")] + pub monthly_limit_usd: f64, + + /// Warn when spending reaches this percentage of limit (default: 80) + #[serde(default = "default_warn_percent")] + pub warn_at_percent: u8, + + /// Allow requests to exceed budget with --override flag (default: false) + #[serde(default)] + pub allow_override: bool, + + /// Per-model pricing (USD per 1M tokens) + #[serde(default)] + pub prices: std::collections::HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelPricing { + /// Input price per 1M tokens + #[serde(default)] + pub input: f64, + + /// Output price per 1M tokens + #[serde(default)] + pub output: f64, +} + +fn default_daily_limit() -> f64 { + 10.0 +} + +fn default_monthly_limit() -> f64 { + 100.0 +} + +fn default_warn_percent() -> u8 { + 80 +} + +impl Default for CostConfig { + fn default() -> Self { + Self { + enabled: false, + daily_limit_usd: default_daily_limit(), + monthly_limit_usd: default_monthly_limit(), + warn_at_percent: default_warn_percent(), + allow_override: false, + prices: get_default_pricing(), + } + } +} + +/// Default pricing for popular models (USD per 1M tokens) +fn get_default_pricing() -> std::collections::HashMap { + let mut prices = std::collections::HashMap::new(); + + // Anthropic models + prices.insert( + "anthropic/claude-sonnet-4-20250514".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-opus-4-20250514".into(), + ModelPricing { + input: 15.0, + output: 75.0, + }, + ); + prices.insert( + "anthropic/claude-3.5-sonnet".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-3-haiku".into(), + ModelPricing { + input: 0.25, + output: 1.25, + }, + ); + + // OpenAI models + prices.insert( + "openai/gpt-4o".into(), + ModelPricing { + input: 5.0, + output: 15.0, + }, + ); + prices.insert( + "openai/gpt-4o-mini".into(), + ModelPricing { + input: 0.15, + output: 0.60, + }, + ); + prices.insert( + "openai/o1-preview".into(), + ModelPricing { + input: 15.0, + output: 60.0, + }, + ); + + // Google models + prices.insert( + "google/gemini-2.0-flash".into(), + ModelPricing { + input: 0.10, + output: 0.40, + }, + ); + prices.insert( + "google/gemini-1.5-pro".into(), + ModelPricing { + input: 1.25, + output: 5.0, + }, + ); + + prices +} + +// ── Peripherals (hardware: STM32, RPi GPIO, etc.) ──────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PeripheralsConfig { + /// Enable peripheral support (boards become agent tools) + #[serde(default)] + pub enabled: bool, + /// Board configurations (nucleo-f401re, rpi-gpio, etc.) + #[serde(default)] + pub boards: Vec, + /// Path to datasheet docs (relative to workspace) for RAG retrieval. + /// Place .md/.txt files named by board (e.g. nucleo-f401re.md, rpi-gpio.md). + #[serde(default)] + pub datasheet_dir: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PeripheralBoardConfig { + /// Board type: "nucleo-f401re", "rpi-gpio", "esp32", etc. + pub board: String, + /// Transport: "serial", "native", "websocket" + #[serde(default = "default_peripheral_transport")] + pub transport: String, + /// Path for serial: "/dev/ttyACM0", "/dev/ttyUSB0" + #[serde(default)] + pub path: Option, + /// Baud rate for serial (default: 115200) + #[serde(default = "default_peripheral_baud")] + pub baud: u32, +} + +fn default_peripheral_transport() -> String { + "serial".into() +} + +fn default_peripheral_baud() -> u32 { + 115_200 +} + +impl Default for PeripheralBoardConfig { + fn default() -> Self { + Self { + board: String::new(), + transport: default_peripheral_transport(), + path: None, + baud: default_peripheral_baud(), + } + } +} + // ── Gateway security ───────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GatewayConfig { - /// Gateway port (default: 3000) + /// Gateway port (default: 8080) #[serde(default = "default_gateway_port")] pub port: u16, - /// Gateway host/bind address (default: 127.0.0.1) + /// Gateway host (default: 127.0.0.1) #[serde(default = "default_gateway_host")] pub host: String, /// Require pairing before accepting requests (default: true) @@ -104,6 +470,18 @@ pub struct GatewayConfig { /// Paired bearer tokens (managed automatically, not user-edited) #[serde(default)] pub paired_tokens: Vec, + + /// Max `/pair` requests per minute per client key. + #[serde(default = "default_pair_rate_limit")] + pub pair_rate_limit_per_minute: u32, + + /// Max `/webhook` requests per minute per client key. + #[serde(default = "default_webhook_rate_limit")] + pub webhook_rate_limit_per_minute: u32, + + /// TTL for webhook idempotency keys. + #[serde(default = "default_idempotency_ttl_secs")] + pub idempotency_ttl_secs: u64, } fn default_gateway_port() -> u16 { @@ -114,6 +492,18 @@ fn default_gateway_host() -> String { "127.0.0.1".into() } +fn default_pair_rate_limit() -> u32 { + 10 +} + +fn default_webhook_rate_limit() -> u32 { + 60 +} + +fn default_idempotency_ttl_secs() -> u64 { + 300 +} + fn default_true() -> bool { true } @@ -126,6 +516,9 @@ impl Default for GatewayConfig { require_pairing: true, allow_public_bind: false, paired_tokens: Vec::new(), + pair_rate_limit_per_minute: default_pair_rate_limit(), + webhook_rate_limit_per_minute: default_webhook_rate_limit(), + idempotency_ttl_secs: default_idempotency_ttl_secs(), } } } @@ -176,24 +569,136 @@ impl Default for SecretsConfig { // ── Browser (friendly-service browsing only) ─────────────────── -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BrowserComputerUseConfig { + /// Sidecar endpoint for computer-use actions (OS-level mouse/keyboard/screenshot) + #[serde(default = "default_browser_computer_use_endpoint")] + pub endpoint: String, + /// Optional bearer token for computer-use sidecar + #[serde(default)] + pub api_key: Option, + /// Per-action request timeout in milliseconds + #[serde(default = "default_browser_computer_use_timeout_ms")] + pub timeout_ms: u64, + /// Allow remote/public endpoint for computer-use sidecar (default: false) + #[serde(default)] + pub allow_remote_endpoint: bool, + /// Optional window title/process allowlist forwarded to sidecar policy + #[serde(default)] + pub window_allowlist: Vec, + /// Optional X-axis boundary for coordinate-based actions + #[serde(default)] + pub max_coordinate_x: Option, + /// Optional Y-axis boundary for coordinate-based actions + #[serde(default)] + pub max_coordinate_y: Option, +} + +fn default_browser_computer_use_endpoint() -> String { + "http://127.0.0.1:8787/v1/actions".into() +} + +fn default_browser_computer_use_timeout_ms() -> u64 { + 15_000 +} + +impl Default for BrowserComputerUseConfig { + fn default() -> Self { + Self { + endpoint: default_browser_computer_use_endpoint(), + api_key: None, + timeout_ms: default_browser_computer_use_timeout_ms(), + allow_remote_endpoint: false, + window_allowlist: Vec::new(), + max_coordinate_x: None, + max_coordinate_y: None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct BrowserConfig { - /// Enable browser tools (`browser_open` and browser automation) + /// Enable `browser_open` tool (opens URLs in Brave without scraping) #[serde(default)] pub enabled: bool, - /// Allowed domains for browser tools (exact or subdomain match) + /// Allowed domains for `browser_open` (exact or subdomain match) #[serde(default)] pub allowed_domains: Vec, - /// Session name for agent-browser (persists state across commands) + /// Browser session name (for agent-browser automation) #[serde(default)] pub session_name: Option, + /// Browser automation backend: "agent_browser" | "rust_native" | "computer_use" | "auto" + #[serde(default = "default_browser_backend")] + pub backend: String, + /// Headless mode for rust-native backend + #[serde(default = "default_true")] + pub native_headless: bool, + /// WebDriver endpoint URL for rust-native backend (e.g. http://127.0.0.1:9515) + #[serde(default = "default_browser_webdriver_url")] + pub native_webdriver_url: String, + /// Optional Chrome/Chromium executable path for rust-native backend + #[serde(default)] + pub native_chrome_path: Option, + /// Computer-use sidecar configuration + #[serde(default)] + pub computer_use: BrowserComputerUseConfig, +} + +fn default_browser_backend() -> String { + "agent_browser".into() +} + +fn default_browser_webdriver_url() -> String { + "http://127.0.0.1:9515".into() +} + +impl Default for BrowserConfig { + fn default() -> Self { + Self { + enabled: false, + allowed_domains: Vec::new(), + session_name: None, + backend: default_browser_backend(), + native_headless: default_true(), + native_webdriver_url: default_browser_webdriver_url(), + native_chrome_path: None, + computer_use: BrowserComputerUseConfig::default(), + } + } +} + +// ── HTTP request tool ─────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HttpRequestConfig { + /// Enable `http_request` tool for API interactions + #[serde(default)] + pub enabled: bool, + /// Allowed domains for HTTP requests (exact or subdomain match) + #[serde(default)] + pub allowed_domains: Vec, + /// Maximum response size in bytes (default: 1MB) + #[serde(default = "default_http_max_response_size")] + pub max_response_size: usize, + /// Request timeout in seconds (default: 30) + #[serde(default = "default_http_timeout_secs")] + pub timeout_secs: u64, +} + +fn default_http_max_response_size() -> usize { + 1_000_000 // 1MB +} + +fn default_http_timeout_secs() -> u64 { + 30 } // ── Memory ─────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] +#[allow(clippy::struct_excessive_bools)] pub struct MemoryConfig { - /// "sqlite" | "markdown" | "none" + /// "sqlite" | "lucid" | "markdown" | "none" (`none` = explicit no-op memory) pub backend: String, /// Auto-save conversation context to memory pub auto_save: bool, @@ -230,6 +735,28 @@ pub struct MemoryConfig { /// Max tokens per chunk for document splitting #[serde(default = "default_chunk_size")] pub chunk_max_tokens: usize, + + // ── Response Cache (saves tokens on repeated prompts) ────── + /// Enable LLM response caching to avoid paying for duplicate prompts + #[serde(default)] + pub response_cache_enabled: bool, + /// TTL in minutes for cached responses (default: 60) + #[serde(default = "default_response_cache_ttl")] + pub response_cache_ttl_minutes: u32, + /// Max number of cached responses before LRU eviction (default: 5000) + #[serde(default = "default_response_cache_max")] + pub response_cache_max_entries: usize, + + // ── Memory Snapshot (soul backup to Markdown) ───────────── + /// Enable periodic export of core memories to MEMORY_SNAPSHOT.md + #[serde(default)] + pub snapshot_enabled: bool, + /// Run snapshot during hygiene passes (heartbeat-driven) + #[serde(default)] + pub snapshot_on_hygiene: bool, + /// Auto-hydrate from MEMORY_SNAPSHOT.md when brain.db is missing + #[serde(default = "default_true")] + pub auto_hydrate: bool, } fn default_embedding_provider() -> String { @@ -265,6 +792,12 @@ fn default_cache_size() -> usize { fn default_chunk_size() -> usize { 512 } +fn default_response_cache_ttl() -> u32 { + 60 +} +fn default_response_cache_max() -> usize { + 5_000 +} impl Default for MemoryConfig { fn default() -> Self { @@ -282,6 +815,12 @@ impl Default for MemoryConfig { keyword_weight: default_keyword_weight(), embedding_cache_size: default_cache_size(), chunk_max_tokens: default_chunk_size(), + response_cache_enabled: false, + response_cache_ttl_minutes: default_response_cache_ttl(), + response_cache_max_entries: default_response_cache_max(), + snapshot_enabled: false, + snapshot_on_hygiene: false, + auto_hydrate: true, } } } @@ -292,12 +831,22 @@ impl Default for MemoryConfig { pub struct ObservabilityConfig { /// "none" | "log" | "prometheus" | "otel" pub backend: String, + + /// OTLP endpoint (e.g. "http://localhost:4318"). Only used when backend = "otel". + #[serde(default)] + pub otel_endpoint: Option, + + /// Service name reported to the OTel collector. Defaults to "zeroclaw". + #[serde(default)] + pub otel_service_name: Option, } impl Default for ObservabilityConfig { fn default() -> Self { Self { backend: "none".into(), + otel_endpoint: None, + otel_service_name: None, } } } @@ -312,6 +861,30 @@ pub struct AutonomyConfig { pub forbidden_paths: Vec, pub max_actions_per_hour: u32, pub max_cost_per_day_cents: u32, + + /// Require explicit approval for medium-risk shell commands. + #[serde(default = "default_true")] + pub require_approval_for_medium_risk: bool, + + /// Block high-risk shell commands even if allowlisted. + #[serde(default = "default_true")] + pub block_high_risk_commands: bool, + + /// Tools that never require approval (e.g. read-only tools). + #[serde(default = "default_auto_approve")] + pub auto_approve: Vec, + + /// Tools that always require interactive approval, even after "Always". + #[serde(default = "default_always_ask")] + pub always_ask: Vec, +} + +fn default_auto_approve() -> Vec { + vec!["file_read".into(), "memory_recall".into()] +} + +fn default_always_ask() -> Vec { + vec![] } impl Default for AutonomyConfig { @@ -355,6 +928,10 @@ impl Default for AutonomyConfig { ], max_actions_per_hour: 20, max_cost_per_day_cents: 500, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, + auto_approve: default_auto_approve(), + always_ask: default_always_ask(), } } } @@ -363,16 +940,85 @@ impl Default for AutonomyConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RuntimeConfig { - /// Runtime kind (currently supported: "native"). - /// - /// Reserved values (not implemented yet): "docker", "cloudflare". + /// Runtime kind (`native` | `docker`). + #[serde(default = "default_runtime_kind")] pub kind: String, + + /// Docker runtime settings (used when `kind = "docker"`). + #[serde(default)] + pub docker: DockerRuntimeConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DockerRuntimeConfig { + /// Runtime image used to execute shell commands. + #[serde(default = "default_docker_image")] + pub image: String, + + /// Docker network mode (`none`, `bridge`, etc.). + #[serde(default = "default_docker_network")] + pub network: String, + + /// Optional memory limit in MB (`None` = no explicit limit). + #[serde(default = "default_docker_memory_limit_mb")] + pub memory_limit_mb: Option, + + /// Optional CPU limit (`None` = no explicit limit). + #[serde(default = "default_docker_cpu_limit")] + pub cpu_limit: Option, + + /// Mount root filesystem as read-only. + #[serde(default = "default_true")] + pub read_only_rootfs: bool, + + /// Mount configured workspace into `/workspace`. + #[serde(default = "default_true")] + pub mount_workspace: bool, + + /// Optional workspace root allowlist for Docker mount validation. + #[serde(default)] + pub allowed_workspace_roots: Vec, +} + +fn default_runtime_kind() -> String { + "native".into() +} + +fn default_docker_image() -> String { + "alpine:3.20".into() +} + +fn default_docker_network() -> String { + "none".into() +} + +fn default_docker_memory_limit_mb() -> Option { + Some(512) +} + +fn default_docker_cpu_limit() -> Option { + Some(1.0) +} + +impl Default for DockerRuntimeConfig { + fn default() -> Self { + Self { + image: default_docker_image(), + network: default_docker_network(), + memory_limit_mb: default_docker_memory_limit_mb(), + cpu_limit: default_docker_cpu_limit(), + read_only_rootfs: true, + mount_workspace: true, + allowed_workspace_roots: Vec::new(), + } + } } impl Default for RuntimeConfig { fn default() -> Self { Self { - kind: "native".into(), + kind: default_runtime_kind(), + docker: DockerRuntimeConfig::default(), } } } @@ -390,6 +1036,14 @@ pub struct ReliabilityConfig { /// Fallback provider chain (e.g. `["anthropic", "openai"]`). #[serde(default)] pub fallback_providers: Vec, + /// Additional API keys for round-robin rotation on rate-limit (429) errors. + /// The primary `api_key` is always tried first; these are extras. + #[serde(default)] + pub api_keys: Vec, + /// Per-model fallback chains. When a model fails, try these alternatives in order. + /// Example: `{ "claude-opus-4-20250514" = ["claude-sonnet-4-20250514", "gpt-4o"] }` + #[serde(default)] + pub model_fallbacks: std::collections::HashMap>, /// Initial backoff for channel/daemon restarts. #[serde(default = "default_channel_backoff_secs")] pub channel_initial_backoff_secs: u64, @@ -434,6 +1088,8 @@ impl Default for ReliabilityConfig { provider_retries: default_provider_retries(), provider_backoff_ms: default_provider_backoff_ms(), fallback_providers: Vec::new(), + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: default_channel_backoff_secs(), channel_max_backoff_secs: default_channel_backoff_max_secs(), scheduler_poll_secs: default_scheduler_poll_secs(), @@ -442,6 +1098,73 @@ impl Default for ReliabilityConfig { } } +// ── Scheduler ──────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SchedulerConfig { + /// Enable the built-in scheduler loop. + #[serde(default = "default_scheduler_enabled")] + pub enabled: bool, + /// Maximum number of persisted scheduled tasks. + #[serde(default = "default_scheduler_max_tasks")] + pub max_tasks: usize, + /// Maximum tasks executed per scheduler polling cycle. + #[serde(default = "default_scheduler_max_concurrent")] + pub max_concurrent: usize, +} + +fn default_scheduler_enabled() -> bool { + true +} + +fn default_scheduler_max_tasks() -> usize { + 64 +} + +fn default_scheduler_max_concurrent() -> usize { + 4 +} + +impl Default for SchedulerConfig { + fn default() -> Self { + Self { + enabled: default_scheduler_enabled(), + max_tasks: default_scheduler_max_tasks(), + max_concurrent: default_scheduler_max_concurrent(), + } + } +} + +// ── Model routing ──────────────────────────────────────────────── + +/// Route a task hint to a specific provider + model. +/// +/// ```toml +/// [[model_routes]] +/// hint = "reasoning" +/// provider = "openrouter" +/// model = "anthropic/claude-opus-4-20250514" +/// +/// [[model_routes]] +/// hint = "fast" +/// provider = "groq" +/// model = "llama-3.3-70b-versatile" +/// ``` +/// +/// Usage: pass `hint:reasoning` as the model parameter to route the request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRouteConfig { + /// Task hint name (e.g. "reasoning", "fast", "code", "summarize") + pub hint: String, + /// Provider to route to (must match a known provider name) + pub provider: String, + /// Model to use with that provider + pub model: String, + /// Optional API key override for this route's provider + #[serde(default)] + pub api_key: Option, +} + // ── Heartbeat ──────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -459,6 +1182,29 @@ impl Default for HeartbeatConfig { } } +// ── Cron ──────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CronConfig { + #[serde(default = "default_true")] + pub enabled: bool, + #[serde(default = "default_max_run_history")] + pub max_run_history: u32, +} + +fn default_max_run_history() -> u32 { + 50 +} + +impl Default for CronConfig { + fn default() -> Self { + Self { + enabled: true, + max_run_history: default_max_run_history(), + } + } +} + // ── Tunnel ────────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -533,10 +1279,17 @@ pub struct ChannelsConfig { pub telegram: Option, pub discord: Option, pub slack: Option, + pub mattermost: Option, pub webhook: Option, pub imessage: Option, pub matrix: Option, + pub signal: Option, pub whatsapp: Option, + pub email: Option, + pub irc: Option, + pub lark: Option, + pub dingtalk: Option, + pub qq: Option, } impl Default for ChannelsConfig { @@ -546,10 +1299,17 @@ impl Default for ChannelsConfig { telegram: None, discord: None, slack: None, + mattermost: None, webhook: None, imessage: None, matrix: None, + signal: None, whatsapp: None, + email: None, + irc: None, + lark: None, + dingtalk: None, + qq: None, } } } @@ -566,6 +1326,14 @@ pub struct DiscordConfig { pub guild_id: Option, #[serde(default)] pub allowed_users: Vec, + /// When true, process messages from other bots (not just humans). + /// The bot still ignores its own messages to prevent feedback loops. + #[serde(default)] + pub listen_to_bots: bool, + /// When true, only respond to messages that @-mention the bot. + /// Other messages in the guild are silently ignored. + #[serde(default)] + pub mention_only: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -577,6 +1345,15 @@ pub struct SlackConfig { pub allowed_users: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MattermostConfig { + pub url: String, + pub bot_token: String, + pub channel_id: Option, + #[serde(default)] + pub allowed_users: Vec, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WebhookConfig { pub port: u16, @@ -596,6 +1373,29 @@ pub struct MatrixConfig { pub allowed_users: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SignalConfig { + /// Base URL for the signal-cli HTTP daemon (e.g. "http://127.0.0.1:8686"). + pub http_url: String, + /// E.164 phone number of the signal-cli account (e.g. "+1234567890"). + pub account: String, + /// Optional group ID to filter messages. + /// - `None` or omitted: accept all messages (DMs and groups) + /// - `"dm"`: only accept direct messages + /// - Specific group ID: only accept messages from that group + #[serde(default)] + pub group_id: Option, + /// Allowed sender phone numbers (E.164) or "*" for all. + #[serde(default)] + pub allowed_from: Vec, + /// Skip messages that are attachment-only (no text body). + #[serde(default)] + pub ignore_attachments: bool, + /// Skip incoming story messages. + #[serde(default)] + pub ignore_stories: bool, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WhatsAppConfig { /// Access token from Meta Business Suite @@ -604,11 +1404,264 @@ pub struct WhatsAppConfig { pub phone_number_id: String, /// Webhook verify token (you define this, Meta sends it back for verification) pub verify_token: String, + /// App secret from Meta Business Suite (for webhook signature verification) + /// Can also be set via `ZEROCLAW_WHATSAPP_APP_SECRET` environment variable + #[serde(default)] + pub app_secret: Option, /// Allowed phone numbers (E.164 format: +1234567890) or "*" for all #[serde(default)] pub allowed_numbers: Vec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IrcConfig { + /// IRC server hostname + pub server: String, + /// IRC server port (default: 6697 for TLS) + #[serde(default = "default_irc_port")] + pub port: u16, + /// Bot nickname + pub nickname: String, + /// Username (defaults to nickname if not set) + pub username: Option, + /// Channels to join on connect + #[serde(default)] + pub channels: Vec, + /// Allowed nicknames (case-insensitive) or "*" for all + #[serde(default)] + pub allowed_users: Vec, + /// Server password (for bouncers like ZNC) + pub server_password: Option, + /// NickServ IDENTIFY password + pub nickserv_password: Option, + /// SASL PLAIN password (IRCv3) + pub sasl_password: Option, + /// Verify TLS certificate (default: true) + pub verify_tls: Option, +} + +fn default_irc_port() -> u16 { + 6697 +} + +/// How ZeroClaw receives events from Feishu / Lark. +/// +/// - `websocket` (default) — persistent WSS long-connection; no public URL required. +/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LarkReceiveMode { + #[default] + Websocket, + Webhook, +} + +/// Lark/Feishu configuration for messaging integration. +/// Lark is the international version; Feishu is the Chinese version. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LarkConfig { + /// App ID from Lark/Feishu developer console + pub app_id: String, + /// App Secret from Lark/Feishu developer console + pub app_secret: String, + /// Encrypt key for webhook message decryption (optional) + #[serde(default)] + pub encrypt_key: Option, + /// Verification token for webhook validation (optional) + #[serde(default)] + pub verification_token: Option, + /// Allowed user IDs or union IDs (empty = deny all, "*" = allow all) + #[serde(default)] + pub allowed_users: Vec, + /// Whether to use the Feishu (Chinese) endpoint instead of Lark (International) + #[serde(default)] + pub use_feishu: bool, + /// Event receive mode: "websocket" (default) or "webhook" + #[serde(default)] + pub receive_mode: LarkReceiveMode, + /// HTTP port for webhook mode only. Must be set when receive_mode = "webhook". + /// Not required (and ignored) for websocket mode. + #[serde(default)] + pub port: Option, +} + +// ── Security Config ───────────────────────────────────────────────── + +/// Security configuration for sandboxing, resource limits, and audit logging +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct SecurityConfig { + /// Sandbox configuration + #[serde(default)] + pub sandbox: SandboxConfig, + + /// Resource limits + #[serde(default)] + pub resources: ResourceLimitsConfig, + + /// Audit logging configuration + #[serde(default)] + pub audit: AuditConfig, +} + +/// Sandbox configuration for OS-level isolation +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SandboxConfig { + /// Enable sandboxing (None = auto-detect, Some = explicit) + #[serde(default)] + pub enabled: Option, + + /// Sandbox backend to use + #[serde(default)] + pub backend: SandboxBackend, + + /// Custom Firejail arguments (when backend = firejail) + #[serde(default)] + pub firejail_args: Vec, +} + +impl Default for SandboxConfig { + fn default() -> Self { + Self { + enabled: None, // Auto-detect + backend: SandboxBackend::Auto, + firejail_args: Vec::new(), + } + } +} + +/// Sandbox backend selection +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum SandboxBackend { + /// Auto-detect best available (default) + #[default] + Auto, + /// Landlock (Linux kernel LSM, native) + Landlock, + /// Firejail (user-space sandbox) + Firejail, + /// Bubblewrap (user namespaces) + Bubblewrap, + /// Docker container isolation + Docker, + /// No sandboxing (application-layer only) + None, +} + +/// Resource limits for command execution +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResourceLimitsConfig { + /// Maximum memory in MB per command + #[serde(default = "default_max_memory_mb")] + pub max_memory_mb: u32, + + /// Maximum CPU time in seconds per command + #[serde(default = "default_max_cpu_time_seconds")] + pub max_cpu_time_seconds: u64, + + /// Maximum number of subprocesses + #[serde(default = "default_max_subprocesses")] + pub max_subprocesses: u32, + + /// Enable memory monitoring + #[serde(default = "default_memory_monitoring_enabled")] + pub memory_monitoring: bool, +} + +fn default_max_memory_mb() -> u32 { + 512 +} + +fn default_max_cpu_time_seconds() -> u64 { + 60 +} + +fn default_max_subprocesses() -> u32 { + 10 +} + +fn default_memory_monitoring_enabled() -> bool { + true +} + +impl Default for ResourceLimitsConfig { + fn default() -> Self { + Self { + max_memory_mb: default_max_memory_mb(), + max_cpu_time_seconds: default_max_cpu_time_seconds(), + max_subprocesses: default_max_subprocesses(), + memory_monitoring: default_memory_monitoring_enabled(), + } + } +} + +/// Audit logging configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditConfig { + /// Enable audit logging + #[serde(default = "default_audit_enabled")] + pub enabled: bool, + + /// Path to audit log file (relative to zeroclaw dir) + #[serde(default = "default_audit_log_path")] + pub log_path: String, + + /// Maximum log size in MB before rotation + #[serde(default = "default_audit_max_size_mb")] + pub max_size_mb: u32, + + /// Sign events with HMAC for tamper evidence + #[serde(default)] + pub sign_events: bool, +} + +fn default_audit_enabled() -> bool { + true +} + +fn default_audit_log_path() -> String { + "audit.log".to_string() +} + +fn default_audit_max_size_mb() -> u32 { + 100 +} + +impl Default for AuditConfig { + fn default() -> Self { + Self { + enabled: default_audit_enabled(), + log_path: default_audit_log_path(), + max_size_mb: default_audit_max_size_mb(), + sign_events: false, + } + } +} + +/// DingTalk configuration for Stream Mode messaging +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DingTalkConfig { + /// Client ID (AppKey) from DingTalk developer console + pub client_id: String, + /// Client Secret (AppSecret) from DingTalk developer console + pub client_secret: String, + /// Allowed user IDs (staff IDs). Empty = deny all, "*" = allow all + #[serde(default)] + pub allowed_users: Vec, +} + +/// QQ Official Bot configuration (Tencent QQ Bot SDK) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QQConfig { + /// App ID from QQ Bot developer console + pub app_id: String, + /// App Secret from QQ Bot developer console + pub app_secret: String, + /// Allowed user IDs. Empty = deny all, "*" = allow all + #[serde(default)] + pub allowed_users: Vec, +} + // ── Config impl ────────────────────────────────────────────────── impl Default for Config { @@ -621,14 +1674,19 @@ impl Default for Config { workspace_dir: zeroclaw_dir.join("workspace"), config_path: zeroclaw_dir.join("config.toml"), api_key: None, + api_url: None, default_provider: Some("openrouter".to_string()), - default_model: Some("anthropic/claude-sonnet-4-20250514".to_string()), + default_model: Some("anthropic/claude-sonnet-4".to_string()), default_temperature: 0.7, observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), + agent: AgentConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), + cron: CronConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), tunnel: TunnelConfig::default(), @@ -636,70 +1694,305 @@ impl Default for Config { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), + peripherals: PeripheralsConfig::default(), + agents: HashMap::new(), + hardware: HardwareConfig::default(), } } } +fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> { + let config_dir = default_config_dir()?; + Ok((config_dir.clone(), config_dir.join("workspace"))) +} + +const ACTIVE_WORKSPACE_STATE_FILE: &str = "active_workspace.toml"; + +#[derive(Debug, Serialize, Deserialize)] +struct ActiveWorkspaceState { + config_dir: String, +} + +fn default_config_dir() -> Result { + let home = UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + Ok(home.join(".zeroclaw")) +} + +fn active_workspace_state_path(default_dir: &Path) -> PathBuf { + default_dir.join(ACTIVE_WORKSPACE_STATE_FILE) +} + +fn load_persisted_workspace_dirs(default_config_dir: &Path) -> Result> { + let state_path = active_workspace_state_path(default_config_dir); + if !state_path.exists() { + return Ok(None); + } + + let contents = match fs::read_to_string(&state_path) { + Ok(contents) => contents, + Err(error) => { + tracing::warn!( + "Failed to read active workspace marker {}: {error}", + state_path.display() + ); + return Ok(None); + } + }; + + let state: ActiveWorkspaceState = match toml::from_str(&contents) { + Ok(state) => state, + Err(error) => { + tracing::warn!( + "Failed to parse active workspace marker {}: {error}", + state_path.display() + ); + return Ok(None); + } + }; + + let raw_config_dir = state.config_dir.trim(); + if raw_config_dir.is_empty() { + tracing::warn!( + "Ignoring active workspace marker {} because config_dir is empty", + state_path.display() + ); + return Ok(None); + } + + let parsed_dir = PathBuf::from(raw_config_dir); + let config_dir = if parsed_dir.is_absolute() { + parsed_dir + } else { + default_config_dir.join(parsed_dir) + }; + Ok(Some((config_dir.clone(), config_dir.join("workspace")))) +} + +pub(crate) fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<()> { + let default_config_dir = default_config_dir()?; + let state_path = active_workspace_state_path(&default_config_dir); + + if config_dir == default_config_dir { + if state_path.exists() { + fs::remove_file(&state_path).with_context(|| { + format!( + "Failed to clear active workspace marker: {}", + state_path.display() + ) + })?; + } + return Ok(()); + } + + fs::create_dir_all(&default_config_dir).with_context(|| { + format!( + "Failed to create default config directory: {}", + default_config_dir.display() + ) + })?; + + let state = ActiveWorkspaceState { + config_dir: config_dir.to_string_lossy().into_owned(), + }; + let serialized = + toml::to_string_pretty(&state).context("Failed to serialize active workspace marker")?; + + let temp_path = default_config_dir.join(format!( + ".{ACTIVE_WORKSPACE_STATE_FILE}.tmp-{}", + uuid::Uuid::new_v4() + )); + fs::write(&temp_path, serialized).with_context(|| { + format!( + "Failed to write temporary active workspace marker: {}", + temp_path.display() + ) + })?; + + if let Err(error) = fs::rename(&temp_path, &state_path) { + let _ = fs::remove_file(&temp_path); + anyhow::bail!( + "Failed to atomically persist active workspace marker {}: {error}", + state_path.display() + ); + } + + sync_directory(&default_config_dir)?; + Ok(()) +} + +fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf { + let workspace_config_dir = workspace_dir.to_path_buf(); + if workspace_config_dir.join("config.toml").exists() { + return workspace_config_dir; + } + + let legacy_config_dir = workspace_dir + .parent() + .map(|parent| parent.join(".zeroclaw")); + if let Some(legacy_dir) = legacy_config_dir { + if legacy_dir.join("config.toml").exists() { + return legacy_dir; + } + + if workspace_dir + .file_name() + .is_some_and(|name| name == std::ffi::OsStr::new("workspace")) + { + return legacy_dir; + } + } + + workspace_config_dir +} + +fn decrypt_optional_secret( + store: &crate::security::SecretStore, + value: &mut Option, + field_name: &str, +) -> Result<()> { + if let Some(raw) = value.clone() { + if crate::security::SecretStore::is_encrypted(&raw) { + *value = Some( + store + .decrypt(&raw) + .with_context(|| format!("Failed to decrypt {field_name}"))?, + ); + } + } + Ok(()) +} + +fn encrypt_optional_secret( + store: &crate::security::SecretStore, + value: &mut Option, + field_name: &str, +) -> Result<()> { + if let Some(raw) = value.clone() { + if !crate::security::SecretStore::is_encrypted(&raw) { + *value = Some( + store + .encrypt(&raw) + .with_context(|| format!("Failed to encrypt {field_name}"))?, + ); + } + } + Ok(()) +} + impl Config { pub fn load_or_init() -> Result { - // Check for workspace override from environment (Docker support) - let zeroclaw_dir = if let Ok(workspace) = std::env::var("ZEROCLAW_WORKSPACE") { - let ws_path = PathBuf::from(&workspace); - ws_path - .parent() - .map_or_else(|| PathBuf::from(&workspace), PathBuf::from) - } else { - let home = UserDirs::new() - .map(|u| u.home_dir().to_path_buf()) - .context("Could not find home directory")?; - home.join(".zeroclaw") + let (default_zeroclaw_dir, default_workspace_dir) = default_config_and_workspace_dirs()?; + + // Resolution priority: + // 1. ZEROCLAW_WORKSPACE env override + // 2. Persisted active workspace marker from onboarding/custom profile + // 3. Default ~/.zeroclaw layout + let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") { + Ok(custom_workspace) if !custom_workspace.is_empty() => { + let workspace = PathBuf::from(custom_workspace); + (resolve_config_dir_for_workspace(&workspace), workspace) + } + _ => load_persisted_workspace_dirs(&default_zeroclaw_dir)? + .unwrap_or((default_zeroclaw_dir, default_workspace_dir)), }; let config_path = zeroclaw_dir.join("config.toml"); - if !zeroclaw_dir.exists() { - fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?; - fs::create_dir_all(zeroclaw_dir.join("workspace")) - .context("Failed to create workspace directory")?; - } + fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?; + fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?; + + if config_path.exists() { + // Warn if config file is world-readable (may contain API keys) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Ok(meta) = fs::metadata(&config_path) { + if meta.permissions().mode() & 0o004 != 0 { + tracing::warn!( + "Config file {:?} is world-readable (mode {:o}). \ + Consider restricting with: chmod 600 {:?}", + config_path, + meta.permissions().mode() & 0o777, + config_path, + ); + } + } + } - let mut config = if config_path.exists() { let contents = fs::read_to_string(&config_path).context("Failed to read config file")?; - toml::from_str(&contents).context("Failed to parse config file")? + let mut config: Config = + toml::from_str(&contents).context("Failed to parse config file")?; + // Set computed paths that are skipped during serialization + config.config_path = config_path.clone(); + config.workspace_dir = workspace_dir; + let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt); + decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?; + decrypt_optional_secret( + &store, + &mut config.composio.api_key, + "config.composio.api_key", + )?; + + decrypt_optional_secret( + &store, + &mut config.browser.computer_use.api_key, + "config.browser.computer_use.api_key", + )?; + + for agent in config.agents.values_mut() { + decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; + } + config.apply_env_overrides(); + Ok(config) } else { - Config::default() - }; - - // Apply environment variable overrides (Docker/container support) - config.apply_env_overrides(); - - // Save config if it didn't exist (creates default config with env overrides) - if !config_path.exists() { + let mut config = Config::default(); + config.config_path = config_path.clone(); + config.workspace_dir = workspace_dir; config.save()?; - } - Ok(config) + // Restrict permissions on newly created config file (may contain API keys) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600)); + } + + config.apply_env_overrides(); + Ok(config) + } } - /// Apply environment variable overrides to config. - /// - /// Supports: - /// - `ZEROCLAW_API_KEY` or `API_KEY` - LLM provider API key - /// - `ZEROCLAW_PROVIDER` or `PROVIDER` - Provider name (openrouter, openai, anthropic, ollama) - /// - `ZEROCLAW_MODEL` - Model name/ID - /// - `ZEROCLAW_WORKSPACE` - Workspace directory path - /// - `ZEROCLAW_GATEWAY_PORT` or `PORT` - Gateway server port - /// - `ZEROCLAW_GATEWAY_HOST` or `HOST` - Gateway bind address - /// - `ZEROCLAW_TEMPERATURE` - Default temperature (0.0-2.0) + /// Apply environment variable overrides to config pub fn apply_env_overrides(&mut self) { - // API Key: ZEROCLAW_API_KEY or API_KEY + // API Key: ZEROCLAW_API_KEY or API_KEY (generic) if let Ok(key) = std::env::var("ZEROCLAW_API_KEY").or_else(|_| std::env::var("API_KEY")) { if !key.is_empty() { self.api_key = Some(key); } } + // API Key: GLM_API_KEY overrides when provider is a GLM/Zhipu variant. + if self.default_provider.as_deref().is_some_and(is_glm_alias) { + if let Ok(key) = std::env::var("GLM_API_KEY") { + if !key.is_empty() { + self.api_key = Some(key); + } + } + } + + // API Key: ZAI_API_KEY overrides when provider is a Z.AI variant. + if self.default_provider.as_deref().is_some_and(is_zai_alias) { + if let Ok(key) = std::env::var("ZAI_API_KEY") { + if !key.is_empty() { + self.api_key = Some(key); + } + } + } // Provider: ZEROCLAW_PROVIDER or PROVIDER if let Ok(provider) = @@ -717,15 +2010,6 @@ impl Config { } } - // Temperature: ZEROCLAW_TEMPERATURE - if let Ok(temp_str) = std::env::var("ZEROCLAW_TEMPERATURE") { - if let Ok(temp) = temp_str.parse::() { - if (0.0..=2.0).contains(&temp) { - self.default_temperature = temp; - } - } - } - // Workspace directory: ZEROCLAW_WORKSPACE if let Ok(workspace) = std::env::var("ZEROCLAW_WORKSPACE") { if !workspace.is_empty() { @@ -749,15 +2033,130 @@ impl Config { self.gateway.host = host; } } + + // Allow public bind: ZEROCLAW_ALLOW_PUBLIC_BIND + if let Ok(val) = std::env::var("ZEROCLAW_ALLOW_PUBLIC_BIND") { + self.gateway.allow_public_bind = val == "1" || val.eq_ignore_ascii_case("true"); + } + + // Temperature: ZEROCLAW_TEMPERATURE + if let Ok(temp_str) = std::env::var("ZEROCLAW_TEMPERATURE") { + if let Ok(temp) = temp_str.parse::() { + if (0.0..=2.0).contains(&temp) { + self.default_temperature = temp; + } + } + } } pub fn save(&self) -> Result<()> { - let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?; - fs::write(&self.config_path, toml_str).context("Failed to write config file")?; + // Encrypt secrets before serialization + let mut config_to_save = self.clone(); + let zeroclaw_dir = self + .config_path + .parent() + .context("Config path must have a parent directory")?; + let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt); + + encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?; + encrypt_optional_secret( + &store, + &mut config_to_save.composio.api_key, + "config.composio.api_key", + )?; + + encrypt_optional_secret( + &store, + &mut config_to_save.browser.computer_use.api_key, + "config.browser.computer_use.api_key", + )?; + + for agent in config_to_save.agents.values_mut() { + encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; + } + + let toml_str = + toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?; + + let parent_dir = self + .config_path + .parent() + .context("Config path must have a parent directory")?; + fs::create_dir_all(parent_dir).with_context(|| { + format!( + "Failed to create config directory: {}", + parent_dir.display() + ) + })?; + + let file_name = self + .config_path + .file_name() + .and_then(|v| v.to_str()) + .unwrap_or("config.toml"); + let temp_path = parent_dir.join(format!(".{file_name}.tmp-{}", uuid::Uuid::new_v4())); + let backup_path = parent_dir.join(format!("{file_name}.bak")); + + let mut temp_file = OpenOptions::new() + .create_new(true) + .write(true) + .open(&temp_path) + .with_context(|| { + format!( + "Failed to create temporary config file: {}", + temp_path.display() + ) + })?; + temp_file + .write_all(toml_str.as_bytes()) + .context("Failed to write temporary config contents")?; + temp_file + .sync_all() + .context("Failed to fsync temporary config file")?; + drop(temp_file); + + let had_existing_config = self.config_path.exists(); + if had_existing_config { + fs::copy(&self.config_path, &backup_path).with_context(|| { + format!( + "Failed to create config backup before atomic replace: {}", + backup_path.display() + ) + })?; + } + + if let Err(e) = fs::rename(&temp_path, &self.config_path) { + let _ = fs::remove_file(&temp_path); + if had_existing_config && backup_path.exists() { + let _ = fs::copy(&backup_path, &self.config_path); + } + anyhow::bail!("Failed to atomically replace config file: {e}"); + } + + sync_directory(parent_dir)?; + + if had_existing_config { + let _ = fs::remove_file(&backup_path); + } + Ok(()) } } +#[cfg(unix)] +fn sync_directory(path: &Path) -> Result<()> { + let dir = File::open(path) + .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?; + dir.sync_all() + .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?; + Ok(()) +} + +#[cfg(not(unix))] +fn sync_directory(_path: &Path) -> Result<()> { + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -792,12 +2191,20 @@ mod tests { assert!(a.forbidden_paths.contains(&"/etc".to_string())); assert_eq!(a.max_actions_per_hour, 20); assert_eq!(a.max_cost_per_day_cents, 500); + assert!(a.require_approval_for_medium_risk); + assert!(a.block_high_risk_commands); } #[test] fn runtime_config_default() { let r = RuntimeConfig::default(); assert_eq!(r.kind, "native"); + assert_eq!(r.docker.image, "alpine:3.20"); + assert_eq!(r.docker.network, "none"); + assert_eq!(r.docker.memory_limit_mb, Some(512)); + assert_eq!(r.docker.cpu_limit, Some(1.0)); + assert!(r.docker.read_only_rootfs); + assert!(r.docker.mount_workspace); } #[test] @@ -807,6 +2214,38 @@ mod tests { assert_eq!(h.interval_minutes, 30); } + #[test] + fn cron_config_default() { + let c = CronConfig::default(); + assert!(c.enabled); + assert_eq!(c.max_run_history, 50); + } + + #[test] + fn cron_config_serde_roundtrip() { + let c = CronConfig { + enabled: false, + max_run_history: 100, + }; + let json = serde_json::to_string(&c).unwrap(); + let parsed: CronConfig = serde_json::from_str(&json).unwrap(); + assert!(!parsed.enabled); + assert_eq!(parsed.max_run_history, 100); + } + + #[test] + fn config_defaults_cron_when_section_missing() { + let toml_str = r#" +workspace_dir = "/tmp/workspace" +config_path = "/tmp/config.toml" +default_temperature = 0.7 +"#; + + let parsed: Config = toml::from_str(toml_str).unwrap(); + assert!(parsed.cron.enabled); + assert_eq!(parsed.cron.max_run_history, 50); + } + #[test] fn memory_config_default_hygiene_settings() { let m = MemoryConfig::default(); @@ -834,11 +2273,13 @@ mod tests { workspace_dir: PathBuf::from("/tmp/test/workspace"), config_path: PathBuf::from("/tmp/test/config.toml"), api_key: Some("sk-test-key".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("gpt-4o".into()), default_temperature: 0.5, observability: ObservabilityConfig { backend: "log".into(), + ..ObservabilityConfig::default() }, autonomy: AutonomyConfig { level: AutonomyLevel::Full, @@ -847,15 +2288,23 @@ mod tests { forbidden_paths: vec!["/secret".into()], max_actions_per_hour: 50, max_cost_per_day_cents: 1000, + require_approval_for_medium_risk: false, + block_high_risk_commands: true, + auto_approve: vec!["file_read".into()], + always_ask: vec![], }, runtime: RuntimeConfig { kind: "docker".into(), + ..RuntimeConfig::default() }, reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig { enabled: true, interval_minutes: 15, }, + cron: CronConfig::default(), channels_config: ChannelsConfig { cli: true, telegram: Some(TelegramConfig { @@ -864,10 +2313,17 @@ mod tests { }), discord: None, slack: None, + mattermost: None, webhook: None, imessage: None, matrix: None, + signal: None, whatsapp: None, + email: None, + irc: None, + lark: None, + dingtalk: None, + qq: None, }, memory: MemoryConfig::default(), tunnel: TunnelConfig::default(), @@ -875,7 +2331,13 @@ mod tests { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), + agent: AgentConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), + peripherals: PeripheralsConfig::default(), + agents: HashMap::new(), + hardware: HardwareConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -919,6 +2381,35 @@ default_temperature = 0.7 assert_eq!(parsed.memory.conversation_retention_days, 30); } + #[test] + fn agent_config_defaults() { + let cfg = AgentConfig::default(); + assert!(!cfg.compact_context); + assert_eq!(cfg.max_tool_iterations, 10); + assert_eq!(cfg.max_history_messages, 50); + assert!(!cfg.parallel_tools); + assert_eq!(cfg.tool_dispatcher, "auto"); + } + + #[test] + fn agent_config_deserializes() { + let raw = r#" +default_temperature = 0.7 +[agent] +compact_context = true +max_tool_iterations = 20 +max_history_messages = 80 +parallel_tools = true +tool_dispatcher = "xml" +"#; + let parsed: Config = toml::from_str(raw).unwrap(); + assert!(parsed.agent.compact_context); + assert_eq!(parsed.agent.max_tool_iterations, 20); + assert_eq!(parsed.agent.max_history_messages, 80); + assert!(parsed.agent.parallel_tools); + assert_eq!(parsed.agent.tool_dispatcher, "xml"); + } + #[test] fn config_save_and_load_tmpdir() { let dir = std::env::temp_dir().join("zeroclaw_test_config"); @@ -930,6 +2421,7 @@ default_temperature = 0.7 workspace_dir: dir.join("workspace"), config_path: config_path.clone(), api_key: Some("sk-roundtrip".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("test-model".into()), default_temperature: 0.9, @@ -937,7 +2429,10 @@ default_temperature = 0.7 autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), + cron: CronConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), tunnel: TunnelConfig::default(), @@ -945,7 +2440,13 @@ default_temperature = 0.7 composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), + agent: AgentConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), + peripherals: PeripheralsConfig::default(), + agents: HashMap::new(), + hardware: HardwareConfig::default(), }; config.save().unwrap(); @@ -953,13 +2454,113 @@ default_temperature = 0.7 let contents = fs::read_to_string(&config_path).unwrap(); let loaded: Config = toml::from_str(&contents).unwrap(); - assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip")); + assert!(loaded + .api_key + .as_deref() + .is_some_and(crate::security::SecretStore::is_encrypted)); + let store = crate::security::SecretStore::new(&dir, true); + let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap(); + assert_eq!(decrypted, "sk-roundtrip"); assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); let _ = fs::remove_dir_all(&dir); } + #[test] + fn config_save_encrypts_nested_credentials() { + let dir = std::env::temp_dir().join(format!( + "zeroclaw_test_nested_credentials_{}", + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&dir).unwrap(); + + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = dir.join("config.toml"); + config.api_key = Some("root-credential".into()); + config.composio.api_key = Some("composio-credential".into()); + config.browser.computer_use.api_key = Some("browser-credential".into()); + + config.agents.insert( + "worker".into(), + DelegateAgentConfig { + provider: "openrouter".into(), + model: "model-test".into(), + system_prompt: None, + api_key: Some("agent-credential".into()), + temperature: None, + max_depth: 3, + }, + ); + + config.save().unwrap(); + + let contents = fs::read_to_string(config.config_path.clone()).unwrap(); + let stored: Config = toml::from_str(&contents).unwrap(); + let store = crate::security::SecretStore::new(&dir, true); + + let root_encrypted = stored.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(root_encrypted)); + assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential"); + + let composio_encrypted = stored.composio.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + composio_encrypted + )); + assert_eq!( + store.decrypt(composio_encrypted).unwrap(), + "composio-credential" + ); + + let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + browser_encrypted + )); + assert_eq!( + store.decrypt(browser_encrypted).unwrap(), + "browser-credential" + ); + + let worker = stored.agents.get("worker").unwrap(); + let worker_encrypted = worker.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(worker_encrypted)); + assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential"); + + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn config_save_atomic_cleanup() { + let dir = + std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4())); + fs::create_dir_all(&dir).unwrap(); + + let config_path = dir.join("config.toml"); + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = config_path.clone(); + config.default_model = Some("model-a".into()); + + config.save().unwrap(); + assert!(config_path.exists()); + + config.default_model = Some("model-b".into()); + config.save().unwrap(); + + let contents = fs::read_to_string(&config_path).unwrap(); + assert!(contents.contains("model-b")); + + let names: Vec = fs::read_dir(&dir) + .unwrap() + .map(|entry| entry.unwrap().file_name().to_string_lossy().to_string()) + .collect(); + assert!(!names.iter().any(|name| name.contains(".tmp-"))); + assert!(!names.iter().any(|name| name.ends_with(".bak"))); + + let _ = fs::remove_dir_all(&dir); + } + // ── Telegram / Discord config ──────────────────────────── #[test] @@ -980,6 +2581,8 @@ default_temperature = 0.7 bot_token: "discord-token".into(), guild_id: Some("12345".into()), allowed_users: vec![], + listen_to_bots: false, + mention_only: false, }; let json = serde_json::to_string(&dc).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); @@ -993,6 +2596,8 @@ default_temperature = 0.7 bot_token: "tok".into(), guild_id: None, allowed_users: vec![], + listen_to_bots: false, + mention_only: false, }; let json = serde_json::to_string(&dc).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); @@ -1062,6 +2667,54 @@ default_temperature = 0.7 assert_eq!(parsed.allowed_users.len(), 2); } + #[test] + fn signal_config_serde() { + let sc = SignalConfig { + http_url: "http://127.0.0.1:8686".into(), + account: "+1234567890".into(), + group_id: Some("group123".into()), + allowed_from: vec!["+1111111111".into()], + ignore_attachments: true, + ignore_stories: false, + }; + let json = serde_json::to_string(&sc).unwrap(); + let parsed: SignalConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.http_url, "http://127.0.0.1:8686"); + assert_eq!(parsed.account, "+1234567890"); + assert_eq!(parsed.group_id.as_deref(), Some("group123")); + assert_eq!(parsed.allowed_from.len(), 1); + assert!(parsed.ignore_attachments); + assert!(!parsed.ignore_stories); + } + + #[test] + fn signal_config_toml_roundtrip() { + let sc = SignalConfig { + http_url: "http://localhost:8080".into(), + account: "+9876543210".into(), + group_id: None, + allowed_from: vec!["*".into()], + ignore_attachments: false, + ignore_stories: true, + }; + let toml_str = toml::to_string(&sc).unwrap(); + let parsed: SignalConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.http_url, "http://localhost:8080"); + assert_eq!(parsed.account, "+9876543210"); + assert!(parsed.group_id.is_none()); + assert!(parsed.ignore_stories); + } + + #[test] + fn signal_config_defaults() { + let json = r#"{"http_url":"http://127.0.0.1:8686","account":"+1234567890"}"#; + let parsed: SignalConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.group_id.is_none()); + assert!(parsed.allowed_from.is_empty()); + assert!(!parsed.ignore_attachments); + assert!(!parsed.ignore_stories); + } + #[test] fn channels_config_with_imessage_and_matrix() { let c = ChannelsConfig { @@ -1069,6 +2722,7 @@ default_temperature = 0.7 telegram: None, discord: None, slack: None, + mattermost: None, webhook: None, imessage: Some(IMessageConfig { allowed_contacts: vec!["+1".into()], @@ -1079,7 +2733,13 @@ default_temperature = 0.7 room_id: "!r:m".into(), allowed_users: vec!["@u:m".into()], }), + signal: None, whatsapp: None, + email: None, + irc: None, + lark: None, + dingtalk: None, + qq: None, }; let toml_str = toml::to_string_pretty(&c).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); @@ -1172,6 +2832,7 @@ channel_id = "C123" access_token: "EAABx...".into(), phone_number_id: "123456789".into(), verify_token: "my-verify-token".into(), + app_secret: None, allowed_numbers: vec!["+1234567890".into(), "+9876543210".into()], }; let json = serde_json::to_string(&wc).unwrap(); @@ -1188,6 +2849,7 @@ channel_id = "C123" access_token: "tok".into(), phone_number_id: "12345".into(), verify_token: "verify".into(), + app_secret: Some("secret123".into()), allowed_numbers: vec!["+1".into()], }; let toml_str = toml::to_string(&wc).unwrap(); @@ -1209,6 +2871,7 @@ channel_id = "C123" access_token: "tok".into(), phone_number_id: "123".into(), verify_token: "ver".into(), + app_secret: None, allowed_numbers: vec!["*".into()], }; let toml_str = toml::to_string(&wc).unwrap(); @@ -1223,15 +2886,23 @@ channel_id = "C123" telegram: None, discord: None, slack: None, + mattermost: None, webhook: None, imessage: None, matrix: None, + signal: None, whatsapp: Some(WhatsAppConfig { access_token: "tok".into(), phone_number_id: "123".into(), verify_token: "ver".into(), + app_secret: None, allowed_numbers: vec!["+1".into()], }), + email: None, + irc: None, + lark: None, + dingtalk: None, + qq: None, }; let toml_str = toml::to_string_pretty(&c).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); @@ -1273,6 +2944,9 @@ channel_id = "C123" g.paired_tokens.is_empty(), "No pre-paired tokens by default" ); + assert_eq!(g.pair_rate_limit_per_minute, 10); + assert_eq!(g.webhook_rate_limit_per_minute, 60); + assert_eq!(g.idempotency_ttl_secs, 300); } #[test] @@ -1298,12 +2972,18 @@ channel_id = "C123" require_pairing: true, allow_public_bind: false, paired_tokens: vec!["zc_test_token".into()], + pair_rate_limit_per_minute: 12, + webhook_rate_limit_per_minute: 80, + idempotency_ttl_secs: 600, }; let toml_str = toml::to_string(&g).unwrap(); let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap(); assert!(parsed.require_pairing); assert!(!parsed.allow_public_bind); assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]); + assert_eq!(parsed.pair_rate_limit_per_minute, 12); + assert_eq!(parsed.webhook_rate_limit_per_minute, 80); + assert_eq!(parsed.idempotency_ttl_secs, 600); } #[test] @@ -1442,6 +3122,16 @@ default_temperature = 0.7 let b = BrowserConfig::default(); assert!(!b.enabled); assert!(b.allowed_domains.is_empty()); + assert_eq!(b.backend, "agent_browser"); + assert!(b.native_headless); + assert_eq!(b.native_webdriver_url, "http://127.0.0.1:9515"); + assert!(b.native_chrome_path.is_none()); + assert_eq!(b.computer_use.endpoint, "http://127.0.0.1:8787/v1/actions"); + assert_eq!(b.computer_use.timeout_ms, 15_000); + assert!(!b.computer_use.allow_remote_endpoint); + assert!(b.computer_use.window_allowlist.is_empty()); + assert!(b.computer_use.max_coordinate_x.is_none()); + assert!(b.computer_use.max_coordinate_y.is_none()); } #[test] @@ -1450,12 +3140,42 @@ default_temperature = 0.7 enabled: true, allowed_domains: vec!["example.com".into(), "docs.example.com".into()], session_name: None, + backend: "auto".into(), + native_headless: false, + native_webdriver_url: "http://localhost:4444".into(), + native_chrome_path: Some("/usr/bin/chromium".into()), + computer_use: BrowserComputerUseConfig { + endpoint: "https://computer-use.example.com/v1/actions".into(), + api_key: Some("test-token".into()), + timeout_ms: 8_000, + allow_remote_endpoint: true, + window_allowlist: vec!["Chrome".into(), "Visual Studio Code".into()], + max_coordinate_x: Some(3840), + max_coordinate_y: Some(2160), + }, }; let toml_str = toml::to_string(&b).unwrap(); let parsed: BrowserConfig = toml::from_str(&toml_str).unwrap(); assert!(parsed.enabled); assert_eq!(parsed.allowed_domains.len(), 2); assert_eq!(parsed.allowed_domains[0], "example.com"); + assert_eq!(parsed.backend, "auto"); + assert!(!parsed.native_headless); + assert_eq!(parsed.native_webdriver_url, "http://localhost:4444"); + assert_eq!( + parsed.native_chrome_path.as_deref(), + Some("/usr/bin/chromium") + ); + assert_eq!( + parsed.computer_use.endpoint, + "https://computer-use.example.com/v1/actions" + ); + assert_eq!(parsed.computer_use.api_key.as_deref(), Some("test-token")); + assert_eq!(parsed.computer_use.timeout_ms, 8_000); + assert!(parsed.computer_use.allow_remote_endpoint); + assert_eq!(parsed.computer_use.window_allowlist.len(), 2); + assert_eq!(parsed.computer_use.max_coordinate_x, Some(3840)); + assert_eq!(parsed.computer_use.max_coordinate_y, Some(2160)); } #[test] @@ -1472,157 +3192,412 @@ default_temperature = 0.7 // ── Environment variable overrides (Docker support) ───────── + fn env_override_test_guard() -> std::sync::MutexGuard<'static, ()> { + static ENV_OVERRIDE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); + ENV_OVERRIDE_TEST_LOCK + .lock() + .expect("env override test lock poisoned") + } + #[test] fn env_override_api_key() { - // Primary and fallback tested together to avoid env-var races. - std::env::remove_var("ZEROCLAW_API_KEY"); - std::env::remove_var("API_KEY"); - - // Primary: ZEROCLAW_API_KEY + let _env_guard = env_override_test_guard(); let mut config = Config::default(); assert!(config.api_key.is_none()); + std::env::set_var("ZEROCLAW_API_KEY", "sk-test-env-key"); config.apply_env_overrides(); assert_eq!(config.api_key.as_deref(), Some("sk-test-env-key")); - std::env::remove_var("ZEROCLAW_API_KEY"); - // Fallback: API_KEY - let mut config2 = Config::default(); + std::env::remove_var("ZEROCLAW_API_KEY"); + } + + #[test] + fn env_override_api_key_fallback() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + + std::env::remove_var("ZEROCLAW_API_KEY"); std::env::set_var("API_KEY", "sk-fallback-key"); - config2.apply_env_overrides(); - assert_eq!(config2.api_key.as_deref(), Some("sk-fallback-key")); + config.apply_env_overrides(); + assert_eq!(config.api_key.as_deref(), Some("sk-fallback-key")); + std::env::remove_var("API_KEY"); } #[test] fn env_override_provider() { - // Primary, fallback, and empty-value tested together to avoid env-var races. - std::env::remove_var("ZEROCLAW_PROVIDER"); - std::env::remove_var("PROVIDER"); - - // Primary: ZEROCLAW_PROVIDER + let _env_guard = env_override_test_guard(); let mut config = Config::default(); + std::env::set_var("ZEROCLAW_PROVIDER", "anthropic"); config.apply_env_overrides(); assert_eq!(config.default_provider.as_deref(), Some("anthropic")); - std::env::remove_var("ZEROCLAW_PROVIDER"); - // Fallback: PROVIDER - let mut config2 = Config::default(); - std::env::set_var("PROVIDER", "openai"); - config2.apply_env_overrides(); - assert_eq!(config2.default_provider.as_deref(), Some("openai")); - std::env::remove_var("PROVIDER"); - - // Empty value should not override - let mut config3 = Config::default(); - let original_provider = config3.default_provider.clone(); - std::env::set_var("ZEROCLAW_PROVIDER", ""); - config3.apply_env_overrides(); - assert_eq!(config3.default_provider, original_provider); std::env::remove_var("ZEROCLAW_PROVIDER"); } + #[test] + fn env_override_provider_fallback() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + + std::env::remove_var("ZEROCLAW_PROVIDER"); + std::env::set_var("PROVIDER", "openai"); + config.apply_env_overrides(); + assert_eq!(config.default_provider.as_deref(), Some("openai")); + + std::env::remove_var("PROVIDER"); + } + + #[test] + fn env_override_glm_api_key_for_regional_aliases() { + let _env_guard = env_override_test_guard(); + let mut config = Config { + default_provider: Some("glm-cn".to_string()), + ..Config::default() + }; + + std::env::set_var("GLM_API_KEY", "glm-regional-key"); + config.apply_env_overrides(); + assert_eq!(config.api_key.as_deref(), Some("glm-regional-key")); + + std::env::remove_var("GLM_API_KEY"); + } + + #[test] + fn env_override_zai_api_key_for_regional_aliases() { + let _env_guard = env_override_test_guard(); + let mut config = Config { + default_provider: Some("zai-cn".to_string()), + ..Config::default() + }; + + std::env::set_var("ZAI_API_KEY", "zai-regional-key"); + config.apply_env_overrides(); + assert_eq!(config.api_key.as_deref(), Some("zai-regional-key")); + + std::env::remove_var("ZAI_API_KEY"); + } + #[test] fn env_override_model() { + let _env_guard = env_override_test_guard(); let mut config = Config::default(); std::env::set_var("ZEROCLAW_MODEL", "gpt-4o"); config.apply_env_overrides(); assert_eq!(config.default_model.as_deref(), Some("gpt-4o")); - // Clean up std::env::remove_var("ZEROCLAW_MODEL"); } #[test] fn env_override_workspace() { + let _env_guard = env_override_test_guard(); let mut config = Config::default(); std::env::set_var("ZEROCLAW_WORKSPACE", "/custom/workspace"); config.apply_env_overrides(); assert_eq!(config.workspace_dir, PathBuf::from("/custom/workspace")); - // Clean up std::env::remove_var("ZEROCLAW_WORKSPACE"); } #[test] - fn env_override_gateway_port() { - // Port, fallback, and invalid tested together to avoid env-var races. - std::env::remove_var("ZEROCLAW_GATEWAY_PORT"); - std::env::remove_var("PORT"); + fn load_or_init_workspace_override_uses_workspace_root_for_config() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("profile-a"); - // Primary: ZEROCLAW_GATEWAY_PORT + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, workspace_dir.join("config.toml")); + assert!(workspace_dir.join("config.toml").exists()); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_workspace_suffix_uses_legacy_config_layout() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("workspace"); + let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml"); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, legacy_config_path); + assert!(config.config_path.exists()); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_workspace_override_keeps_existing_legacy_config() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("custom-workspace"); + let legacy_config_dir = temp_home.join(".zeroclaw"); + let legacy_config_path = legacy_config_dir.join("config.toml"); + + fs::create_dir_all(&legacy_config_dir).unwrap(); + fs::write( + &legacy_config_path, + r#"default_temperature = 0.7 +default_model = "legacy-model" +"#, + ) + .unwrap(); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, legacy_config_path); + assert_eq!(config.default_model.as_deref(), Some("legacy-model")); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_uses_persisted_active_workspace_marker() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let custom_config_dir = temp_home.join("profiles").join("agent-alpha"); + + fs::create_dir_all(&custom_config_dir).unwrap(); + fs::write( + custom_config_dir.join("config.toml"), + "default_temperature = 0.7\ndefault_model = \"persisted-profile\"\n", + ) + .unwrap(); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::remove_var("ZEROCLAW_WORKSPACE"); + + persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.config_path, custom_config_dir.join("config.toml")); + assert_eq!(config.workspace_dir, custom_config_dir.join("workspace")); + assert_eq!(config.default_model.as_deref(), Some("persisted-profile")); + + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_env_workspace_override_takes_priority_over_marker() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let marker_config_dir = temp_home.join("profiles").join("persisted-profile"); + let env_workspace_dir = temp_home.join("env-workspace"); + + fs::create_dir_all(&marker_config_dir).unwrap(); + fs::write( + marker_config_dir.join("config.toml"), + "default_temperature = 0.7\ndefault_model = \"marker-model\"\n", + ) + .unwrap(); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + persist_active_workspace_config_dir(&marker_config_dir).unwrap(); + std::env::set_var("ZEROCLAW_WORKSPACE", &env_workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, env_workspace_dir); + assert_eq!(config.config_path, env_workspace_dir.join("config.toml")); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn persist_active_workspace_marker_is_cleared_for_default_config_dir() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let default_config_dir = temp_home.join(".zeroclaw"); + let custom_config_dir = temp_home.join("profiles").join("custom-profile"); + let marker_path = default_config_dir.join(ACTIVE_WORKSPACE_STATE_FILE); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + + persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + assert!(marker_path.exists()); + + persist_active_workspace_config_dir(&default_config_dir).unwrap(); + assert!(!marker_path.exists()); + + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn env_override_empty_values_ignored() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + let original_provider = config.default_provider.clone(); + + std::env::set_var("ZEROCLAW_PROVIDER", ""); + config.apply_env_overrides(); + assert_eq!(config.default_provider, original_provider); + + std::env::remove_var("ZEROCLAW_PROVIDER"); + } + + #[test] + fn env_override_gateway_port() { + let _env_guard = env_override_test_guard(); let mut config = Config::default(); assert_eq!(config.gateway.port, 3000); + std::env::set_var("ZEROCLAW_GATEWAY_PORT", "8080"); config.apply_env_overrides(); assert_eq!(config.gateway.port, 8080); + std::env::remove_var("ZEROCLAW_GATEWAY_PORT"); + } - // Fallback: PORT - let mut config2 = Config::default(); + #[test] + fn env_override_port_fallback() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + + std::env::remove_var("ZEROCLAW_GATEWAY_PORT"); std::env::set_var("PORT", "9000"); - config2.apply_env_overrides(); - assert_eq!(config2.gateway.port, 9000); - - // Invalid PORT is ignored - let mut config3 = Config::default(); - let original_port = config3.gateway.port; - std::env::set_var("PORT", "not_a_number"); - config3.apply_env_overrides(); - assert_eq!(config3.gateway.port, original_port); + config.apply_env_overrides(); + assert_eq!(config.gateway.port, 9000); std::env::remove_var("PORT"); } #[test] fn env_override_gateway_host() { - // Primary and fallback tested together to avoid env-var races. - std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); - std::env::remove_var("HOST"); - - // Primary: ZEROCLAW_GATEWAY_HOST + let _env_guard = env_override_test_guard(); let mut config = Config::default(); assert_eq!(config.gateway.host, "127.0.0.1"); + std::env::set_var("ZEROCLAW_GATEWAY_HOST", "0.0.0.0"); config.apply_env_overrides(); assert_eq!(config.gateway.host, "0.0.0.0"); - std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); - // Fallback: HOST - let mut config2 = Config::default(); + std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); + } + + #[test] + fn env_override_host_fallback() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + + std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); std::env::set_var("HOST", "0.0.0.0"); - config2.apply_env_overrides(); - assert_eq!(config2.gateway.host, "0.0.0.0"); + config.apply_env_overrides(); + assert_eq!(config.gateway.host, "0.0.0.0"); + std::env::remove_var("HOST"); } #[test] fn env_override_temperature() { - // Valid and out-of-range tested together to avoid env-var races. - std::env::remove_var("ZEROCLAW_TEMPERATURE"); - - // Valid temperature is applied + let _env_guard = env_override_test_guard(); let mut config = Config::default(); + std::env::set_var("ZEROCLAW_TEMPERATURE", "0.5"); config.apply_env_overrides(); assert!((config.default_temperature - 0.5).abs() < f64::EPSILON); - // Out-of-range temperature is ignored - let mut config2 = Config::default(); - let original_temp = config2.default_temperature; + std::env::remove_var("ZEROCLAW_TEMPERATURE"); + } + + #[test] + fn env_override_temperature_out_of_range_ignored() { + let _env_guard = env_override_test_guard(); + // Clean up any leftover env vars from other tests + std::env::remove_var("ZEROCLAW_TEMPERATURE"); + + let mut config = Config::default(); + let original_temp = config.default_temperature; + + // Temperature > 2.0 should be ignored std::env::set_var("ZEROCLAW_TEMPERATURE", "3.0"); - config2.apply_env_overrides(); + config.apply_env_overrides(); assert!( - (config2.default_temperature - original_temp).abs() < f64::EPSILON, + (config.default_temperature - original_temp).abs() < f64::EPSILON, "Temperature 3.0 should be ignored (out of range)" ); std::env::remove_var("ZEROCLAW_TEMPERATURE"); } + #[test] + fn env_override_invalid_port_ignored() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + let original_port = config.gateway.port; + + std::env::set_var("PORT", "not_a_number"); + config.apply_env_overrides(); + assert_eq!(config.gateway.port, original_port); + + std::env::remove_var("PORT"); + } + #[test] fn gateway_config_default_values() { let g = GatewayConfig::default(); @@ -1632,4 +3607,156 @@ default_temperature = 0.7 assert!(!g.allow_public_bind); assert!(g.paired_tokens.is_empty()); } + + // ── Peripherals config ─────────────────────────────────────── + + #[test] + fn peripherals_config_default_disabled() { + let p = PeripheralsConfig::default(); + assert!(!p.enabled); + assert!(p.boards.is_empty()); + } + + #[test] + fn peripheral_board_config_defaults() { + let b = PeripheralBoardConfig::default(); + assert!(b.board.is_empty()); + assert_eq!(b.transport, "serial"); + assert!(b.path.is_none()); + assert_eq!(b.baud, 115_200); + } + + #[test] + fn peripherals_config_toml_roundtrip() { + let p = PeripheralsConfig { + enabled: true, + boards: vec![PeripheralBoardConfig { + board: "nucleo-f401re".into(), + transport: "serial".into(), + path: Some("/dev/ttyACM0".into()), + baud: 115_200, + }], + datasheet_dir: None, + }; + let toml_str = toml::to_string(&p).unwrap(); + let parsed: PeripheralsConfig = toml::from_str(&toml_str).unwrap(); + assert!(parsed.enabled); + assert_eq!(parsed.boards.len(), 1); + assert_eq!(parsed.boards[0].board, "nucleo-f401re"); + assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0")); + } + + #[test] + fn lark_config_serde() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["user_123".into(), "user_456".into()], + use_feishu: true, + receive_mode: LarkReceiveMode::Websocket, + port: None, + }; + let json = serde_json::to_string(&lc).unwrap(); + let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key")); + assert_eq!(parsed.verification_token.as_deref(), Some("verify_token")); + assert_eq!(parsed.allowed_users.len(), 2); + assert!(parsed.use_feishu); + } + + #[test] + fn lark_config_toml_roundtrip() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + let toml_str = toml::to_string(&lc).unwrap(); + let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_deserializes_without_optional_fields() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.encrypt_key.is_none()); + assert!(parsed.verification_token.is_none()); + assert!(parsed.allowed_users.is_empty()); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_defaults_to_lark_endpoint() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!( + !parsed.use_feishu, + "use_feishu should default to false (Lark)" + ); + } + + #[test] + fn lark_config_with_wildcard_allowed_users() { + let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert_eq!(parsed.allowed_users, vec!["*"]); + } + + // ── Config file permission hardening (Unix only) ─────────────── + + #[cfg(unix)] + #[test] + fn new_config_file_has_restricted_permissions() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::TempDir::new().unwrap(); + let config_path = tmp.path().join("config.toml"); + + // Create a config and save it + let mut config = Config::default(); + config.config_path = config_path.clone(); + config.save().unwrap(); + + // Apply the same permission logic as load_or_init + let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600)); + + let meta = std::fs::metadata(&config_path).unwrap(); + let mode = meta.permissions().mode() & 0o777; + assert_eq!( + mode, 0o600, + "New config file should be owner-only (0600), got {mode:o}" + ); + } + + #[cfg(unix)] + #[test] + fn world_readable_config_is_detectable() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::TempDir::new().unwrap(); + let config_path = tmp.path().join("config.toml"); + + // Create a config file with intentionally loose permissions + std::fs::write(&config_path, "# test config").unwrap(); + std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap(); + + let meta = std::fs::metadata(&config_path).unwrap(); + let mode = meta.permissions().mode(); + assert!( + mode & 0o004 != 0, + "Test setup: file should be world-readable (mode {mode:o})" + ); + } } diff --git a/src/cost/mod.rs b/src/cost/mod.rs new file mode 100644 index 0000000..14c634d --- /dev/null +++ b/src/cost/mod.rs @@ -0,0 +1,5 @@ +pub mod tracker; +pub mod types; + +pub use tracker::CostTracker; +pub use types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; diff --git a/src/cost/tracker.rs b/src/cost/tracker.rs new file mode 100644 index 0000000..1905b36 --- /dev/null +++ b/src/cost/tracker.rs @@ -0,0 +1,536 @@ +use super::types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; +use crate::config::schema::CostConfig; +use anyhow::{anyhow, Context, Result}; +use chrono::{Datelike, NaiveDate, Utc}; +use parking_lot::{Mutex, MutexGuard}; +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +/// Cost tracker for API usage monitoring and budget enforcement. +pub struct CostTracker { + config: CostConfig, + storage: Arc>, + session_id: String, + session_costs: Arc>>, +} + +impl CostTracker { + /// Create a new cost tracker. + pub fn new(config: CostConfig, workspace_dir: &Path) -> Result { + let storage_path = resolve_storage_path(workspace_dir)?; + + let storage = CostStorage::new(&storage_path).with_context(|| { + format!("Failed to open cost storage at {}", storage_path.display()) + })?; + + Ok(Self { + config, + storage: Arc::new(Mutex::new(storage)), + session_id: uuid::Uuid::new_v4().to_string(), + session_costs: Arc::new(Mutex::new(Vec::new())), + }) + } + + /// Get the session ID. + pub fn session_id(&self) -> &str { + &self.session_id + } + + fn lock_storage(&self) -> MutexGuard<'_, CostStorage> { + self.storage.lock() + } + + fn lock_session_costs(&self) -> MutexGuard<'_, Vec> { + self.session_costs.lock() + } + + /// Check if a request is within budget. + pub fn check_budget(&self, estimated_cost_usd: f64) -> Result { + if !self.config.enabled { + return Ok(BudgetCheck::Allowed); + } + + if !estimated_cost_usd.is_finite() || estimated_cost_usd < 0.0 { + return Err(anyhow!( + "Estimated cost must be a finite, non-negative value" + )); + } + + let mut storage = self.lock_storage(); + let (daily_cost, monthly_cost) = storage.get_aggregated_costs()?; + + // Check daily limit + let projected_daily = daily_cost + estimated_cost_usd; + if projected_daily > self.config.daily_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + // Check monthly limit + let projected_monthly = monthly_cost + estimated_cost_usd; + if projected_monthly > self.config.monthly_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + // Check warning thresholds + let warn_threshold = f64::from(self.config.warn_at_percent.min(100)) / 100.0; + let daily_warn_threshold = self.config.daily_limit_usd * warn_threshold; + let monthly_warn_threshold = self.config.monthly_limit_usd * warn_threshold; + + if projected_daily >= daily_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + if projected_monthly >= monthly_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + Ok(BudgetCheck::Allowed) + } + + /// Record a usage event. + pub fn record_usage(&self, usage: TokenUsage) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + if !usage.cost_usd.is_finite() || usage.cost_usd < 0.0 { + return Err(anyhow!( + "Token usage cost must be a finite, non-negative value" + )); + } + + let record = CostRecord::new(&self.session_id, usage); + + // Persist first for durability guarantees. + { + let mut storage = self.lock_storage(); + storage.add_record(record.clone())?; + } + + // Then update in-memory session snapshot. + let mut session_costs = self.lock_session_costs(); + session_costs.push(record); + + Ok(()) + } + + /// Get the current cost summary. + pub fn get_summary(&self) -> Result { + let (daily_cost, monthly_cost) = { + let mut storage = self.lock_storage(); + storage.get_aggregated_costs()? + }; + + let session_costs = self.lock_session_costs(); + let session_cost: f64 = session_costs + .iter() + .map(|record| record.usage.cost_usd) + .sum(); + let total_tokens: u64 = session_costs + .iter() + .map(|record| record.usage.total_tokens) + .sum(); + let request_count = session_costs.len(); + let by_model = build_session_model_stats(&session_costs); + + Ok(CostSummary { + session_cost_usd: session_cost, + daily_cost_usd: daily_cost, + monthly_cost_usd: monthly_cost, + total_tokens, + request_count, + by_model, + }) + } + + /// Get the daily cost for a specific date. + pub fn get_daily_cost(&self, date: NaiveDate) -> Result { + let storage = self.lock_storage(); + storage.get_cost_for_date(date) + } + + /// Get the monthly cost for a specific month. + pub fn get_monthly_cost(&self, year: i32, month: u32) -> Result { + let storage = self.lock_storage(); + storage.get_cost_for_month(year, month) + } +} + +fn resolve_storage_path(workspace_dir: &Path) -> Result { + let storage_path = workspace_dir.join("state").join("costs.jsonl"); + let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db"); + + if !storage_path.exists() && legacy_path.exists() { + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + if let Err(error) = fs::rename(&legacy_path, &storage_path) { + tracing::warn!( + "Failed to move legacy cost storage from {} to {}: {error}; falling back to copy", + legacy_path.display(), + storage_path.display() + ); + fs::copy(&legacy_path, &storage_path).with_context(|| { + format!( + "Failed to copy legacy cost storage from {} to {}", + legacy_path.display(), + storage_path.display() + ) + })?; + } + } + + Ok(storage_path) +} + +fn build_session_model_stats(session_costs: &[CostRecord]) -> HashMap { + let mut by_model: HashMap = HashMap::new(); + + for record in session_costs { + let entry = by_model + .entry(record.usage.model.clone()) + .or_insert_with(|| ModelStats { + model: record.usage.model.clone(), + cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + }); + + entry.cost_usd += record.usage.cost_usd; + entry.total_tokens += record.usage.total_tokens; + entry.request_count += 1; + } + + by_model +} + +/// Persistent storage for cost records. +struct CostStorage { + path: PathBuf, + daily_cost_usd: f64, + monthly_cost_usd: f64, + cached_day: NaiveDate, + cached_year: i32, + cached_month: u32, +} + +impl CostStorage { + /// Create or open cost storage. + fn new(path: &Path) -> Result { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let now = Utc::now(); + let mut storage = Self { + path: path.to_path_buf(), + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + cached_day: now.date_naive(), + cached_year: now.year(), + cached_month: now.month(), + }; + + storage.rebuild_aggregates( + storage.cached_day, + storage.cached_year, + storage.cached_month, + )?; + + Ok(storage) + } + + fn for_each_record(&self, mut on_record: F) -> Result<()> + where + F: FnMut(CostRecord), + { + if !self.path.exists() { + return Ok(()); + } + + let file = File::open(&self.path) + .with_context(|| format!("Failed to read cost storage from {}", self.path.display()))?; + let reader = BufReader::new(file); + + for (line_number, line) in reader.lines().enumerate() { + let raw_line = line.with_context(|| { + format!( + "Failed to read line {} from cost storage {}", + line_number + 1, + self.path.display() + ) + })?; + + let trimmed = raw_line.trim(); + if trimmed.is_empty() { + continue; + } + + match serde_json::from_str::(trimmed) { + Ok(record) => on_record(record), + Err(error) => { + tracing::warn!( + "Skipping malformed cost record at {}:{}: {error}", + self.path.display(), + line_number + 1 + ); + } + } + } + + Ok(()) + } + + fn rebuild_aggregates(&mut self, day: NaiveDate, year: i32, month: u32) -> Result<()> { + let mut daily_cost = 0.0; + let mut monthly_cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + + if timestamp.date() == day { + daily_cost += record.usage.cost_usd; + } + + if timestamp.year() == year && timestamp.month() == month { + monthly_cost += record.usage.cost_usd; + } + })?; + + self.daily_cost_usd = daily_cost; + self.monthly_cost_usd = monthly_cost; + self.cached_day = day; + self.cached_year = year; + self.cached_month = month; + + Ok(()) + } + + fn ensure_period_cache_current(&mut self) -> Result<()> { + let now = Utc::now(); + let day = now.date_naive(); + let year = now.year(); + let month = now.month(); + + if day != self.cached_day || year != self.cached_year || month != self.cached_month { + self.rebuild_aggregates(day, year, month)?; + } + + Ok(()) + } + + /// Add a new record. + fn add_record(&mut self, record: CostRecord) -> Result<()> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.path) + .with_context(|| format!("Failed to open cost storage at {}", self.path.display()))?; + + writeln!(file, "{}", serde_json::to_string(&record)?) + .with_context(|| format!("Failed to write cost record to {}", self.path.display()))?; + file.sync_all() + .with_context(|| format!("Failed to sync cost storage at {}", self.path.display()))?; + + self.ensure_period_cache_current()?; + + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.date() == self.cached_day { + self.daily_cost_usd += record.usage.cost_usd; + } + if timestamp.year() == self.cached_year && timestamp.month() == self.cached_month { + self.monthly_cost_usd += record.usage.cost_usd; + } + + Ok(()) + } + + /// Get aggregated costs for current day and month. + fn get_aggregated_costs(&mut self) -> Result<(f64, f64)> { + self.ensure_period_cache_current()?; + Ok((self.daily_cost_usd, self.monthly_cost_usd)) + } + + /// Get cost for a specific date. + fn get_cost_for_date(&self, date: NaiveDate) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + if record.usage.timestamp.naive_utc().date() == date { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } + + /// Get cost for a specific month. + fn get_cost_for_month(&self, year: i32, month: u32) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.year() == year && timestamp.month() == month { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn enabled_config() -> CostConfig { + CostConfig { + enabled: true, + ..Default::default() + } + } + + #[test] + fn cost_tracker_initialization() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + assert!(!tracker.session_id().is_empty()); + } + + #[test] + fn budget_check_when_disabled() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: false, + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + let check = tracker.check_budget(1000.0).unwrap(); + assert!(matches!(check, BudgetCheck::Allowed)); + } + + #[test] + fn record_usage_and_get_summary() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let usage = TokenUsage::new("test/model", 1000, 500, 1.0, 2.0); + tracker.record_usage(usage).unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.request_count, 1); + assert!(summary.session_cost_usd > 0.0); + assert_eq!(summary.by_model.len(), 1); + } + + #[test] + fn budget_exceeded_daily_limit() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: true, + daily_limit_usd: 0.01, // Very low limit + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + + // Record a usage that exceeds the limit + let usage = TokenUsage::new("test/model", 10000, 5000, 1.0, 2.0); // ~0.02 USD + tracker.record_usage(usage).unwrap(); + + let check = tracker.check_budget(0.01).unwrap(); + assert!(matches!(check, BudgetCheck::Exceeded { .. })); + } + + #[test] + fn summary_by_model_is_session_scoped() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let old_record = CostRecord::new( + "old-session", + TokenUsage::new("legacy/model", 500, 500, 1.0, 1.0), + ); + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&old_record).unwrap()).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + tracker + .record_usage(TokenUsage::new("session/model", 1000, 1000, 1.0, 1.0)) + .unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.by_model.len(), 1); + assert!(summary.by_model.contains_key("session/model")); + assert!(!summary.by_model.contains_key("legacy/model")); + } + + #[test] + fn malformed_lines_are_ignored_while_loading() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let valid_usage = TokenUsage::new("test/model", 1000, 0, 1.0, 1.0); + let valid_record = CostRecord::new("session-a", valid_usage.clone()); + + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&valid_record).unwrap()).unwrap(); + writeln!(file, "not-a-json-line").unwrap(); + writeln!(file).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + let today_cost = tracker.get_daily_cost(Utc::now().date_naive()).unwrap(); + assert!((today_cost - valid_usage.cost_usd).abs() < f64::EPSILON); + } + + #[test] + fn invalid_budget_estimate_is_rejected() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let err = tracker.check_budget(f64::NAN).unwrap_err(); + assert!(err + .to_string() + .contains("Estimated cost must be a finite, non-negative value")); + } +} diff --git a/src/cost/types.rs b/src/cost/types.rs new file mode 100644 index 0000000..0e8d167 --- /dev/null +++ b/src/cost/types.rs @@ -0,0 +1,193 @@ +use serde::{Deserialize, Serialize}; + +/// Token usage information from a single API call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + /// Model identifier (e.g., "anthropic/claude-sonnet-4-20250514") + pub model: String, + /// Input/prompt tokens + pub input_tokens: u64, + /// Output/completion tokens + pub output_tokens: u64, + /// Total tokens + pub total_tokens: u64, + /// Calculated cost in USD + pub cost_usd: f64, + /// Timestamp of the request + pub timestamp: chrono::DateTime, +} + +impl TokenUsage { + fn sanitize_price(value: f64) -> f64 { + if value.is_finite() && value > 0.0 { + value + } else { + 0.0 + } + } + + /// Create a new token usage record. + pub fn new( + model: impl Into, + input_tokens: u64, + output_tokens: u64, + input_price_per_million: f64, + output_price_per_million: f64, + ) -> Self { + let model = model.into(); + let input_price_per_million = Self::sanitize_price(input_price_per_million); + let output_price_per_million = Self::sanitize_price(output_price_per_million); + let total_tokens = input_tokens.saturating_add(output_tokens); + + // Calculate cost: (tokens / 1M) * price_per_million + let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price_per_million; + let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million; + let cost_usd = input_cost + output_cost; + + Self { + model, + input_tokens, + output_tokens, + total_tokens, + cost_usd, + timestamp: chrono::Utc::now(), + } + } + + /// Get the total cost. + pub fn cost(&self) -> f64 { + self.cost_usd + } +} + +/// Time period for cost aggregation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum UsagePeriod { + Session, + Day, + Month, +} + +/// A single cost record for persistent storage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostRecord { + /// Unique identifier + pub id: String, + /// Token usage details + pub usage: TokenUsage, + /// Session identifier (for grouping) + pub session_id: String, +} + +impl CostRecord { + /// Create a new cost record. + pub fn new(session_id: impl Into, usage: TokenUsage) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + usage, + session_id: session_id.into(), + } + } +} + +/// Budget enforcement result. +#[derive(Debug, Clone)] +pub enum BudgetCheck { + /// Within budget, request can proceed + Allowed, + /// Warning threshold exceeded but request can proceed + Warning { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, + /// Budget exceeded, request blocked + Exceeded { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, +} + +/// Cost summary for reporting. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + /// Total cost for the session + pub session_cost_usd: f64, + /// Total cost for the day + pub daily_cost_usd: f64, + /// Total cost for the month + pub monthly_cost_usd: f64, + /// Total tokens used + pub total_tokens: u64, + /// Number of requests + pub request_count: usize, + /// Breakdown by model + pub by_model: std::collections::HashMap, +} + +/// Statistics for a specific model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelStats { + /// Model name + pub model: String, + /// Total cost for this model + pub cost_usd: f64, + /// Total tokens for this model + pub total_tokens: u64, + /// Number of requests for this model + pub request_count: usize, +} + +impl Default for CostSummary { + fn default() -> Self { + Self { + session_cost_usd: 0.0, + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + by_model: std::collections::HashMap::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_usage_calculation() { + let usage = TokenUsage::new("test/model", 1000, 500, 3.0, 15.0); + + // Expected: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105 + assert!((usage.cost_usd - 0.0105).abs() < 0.0001); + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.total_tokens, 1500); + } + + #[test] + fn token_usage_zero_tokens() { + let usage = TokenUsage::new("test/model", 0, 0, 3.0, 15.0); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 0); + } + + #[test] + fn token_usage_negative_or_non_finite_prices_are_clamped() { + let usage = TokenUsage::new("test/model", 1000, 1000, -3.0, f64::NAN); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 2000); + } + + #[test] + fn cost_record_creation() { + let usage = TokenUsage::new("test/model", 100, 50, 1.0, 2.0); + let record = CostRecord::new("session-123", usage); + + assert_eq!(record.session_id, "session-123"); + assert!(!record.id.is_empty()); + assert_eq!(record.usage.model, "test/model"); + } +} diff --git a/src/cron/mod.rs b/src/cron/mod.rs index 9866ec5..0f39bc7 100644 --- a/src/cron/mod.rs +++ b/src/cron/mod.rs @@ -1,27 +1,27 @@ use crate::config::Config; -use anyhow::{Context, Result}; -use chrono::{DateTime, Utc}; -use cron::Schedule; -use rusqlite::{params, Connection}; -use std::str::FromStr; -use uuid::Uuid; +use anyhow::Result; + +mod schedule; +mod store; +mod types; pub mod scheduler; -#[derive(Debug, Clone)] -pub struct CronJob { - pub id: String, - pub expression: String, - pub command: String, - pub next_run: DateTime, - pub last_run: Option>, - pub last_status: Option, -} +#[allow(unused_imports)] +pub use schedule::{ + next_run_for_schedule, normalize_expression, schedule_cron_expression, validate_schedule, +}; +#[allow(unused_imports)] +pub use store::{ + add_agent_job, add_job, add_shell_job, due_jobs, get_job, list_jobs, list_runs, + record_last_run, record_run, remove_job, reschedule_after_run, update_job, +}; +pub use types::{CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType, Schedule, SessionTarget}; #[allow(clippy::needless_pass_by_value)] -pub fn handle_command(command: super::CronCommands, config: &Config) -> Result<()> { +pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<()> { match command { - super::CronCommands::List => { + crate::CronCommands::List => { let jobs = list_jobs(config)?; if jobs.is_empty() { println!("No scheduled tasks yet."); @@ -37,319 +37,133 @@ pub fn handle_command(command: super::CronCommands, config: &Config) -> Result<( .map_or_else(|| "never".into(), |d| d.to_rfc3339()); let last_status = job.last_status.unwrap_or_else(|| "n/a".into()); println!( - "- {} | {} | next={} | last={} ({})\n cmd: {}", + "- {} | {:?} | next={} | last={} ({})", job.id, - job.expression, + job.schedule, job.next_run.to_rfc3339(), last_run, last_status, - job.command ); + if !job.command.is_empty() { + println!(" cmd: {}", job.command); + } + if let Some(prompt) = &job.prompt { + println!(" prompt: {prompt}"); + } } Ok(()) } - super::CronCommands::Add { + crate::CronCommands::Add { expression, + tz, command, } => { - let job = add_job(config, &expression, &command)?; + let schedule = Schedule::Cron { + expr: expression, + tz, + }; + let job = add_shell_job(config, None, schedule, &command)?; println!("✅ Added cron job {}", job.id); println!(" Expr: {}", job.expression); println!(" Next: {}", job.next_run.to_rfc3339()); println!(" Cmd : {}", job.command); Ok(()) } - super::CronCommands::Remove { id } => remove_job(config, &id), + crate::CronCommands::AddAt { at, command } => { + let at = chrono::DateTime::parse_from_rfc3339(&at) + .map_err(|e| anyhow::anyhow!("Invalid RFC3339 timestamp for --at: {e}"))? + .with_timezone(&chrono::Utc); + let schedule = Schedule::At { at }; + let job = add_shell_job(config, None, schedule, &command)?; + println!("✅ Added one-shot cron job {}", job.id); + println!(" At : {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); + Ok(()) + } + crate::CronCommands::AddEvery { every_ms, command } => { + let schedule = Schedule::Every { every_ms }; + let job = add_shell_job(config, None, schedule, &command)?; + println!("✅ Added interval cron job {}", job.id); + println!(" Every(ms): {every_ms}"); + println!(" Next : {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); + Ok(()) + } + crate::CronCommands::Once { delay, command } => { + let job = add_once(config, &delay, &command)?; + println!("✅ Added one-shot cron job {}", job.id); + println!(" At : {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); + Ok(()) + } + crate::CronCommands::Remove { id } => remove_job(config, &id), + crate::CronCommands::Pause { id } => { + pause_job(config, &id)?; + println!("⏸️ Paused cron job {id}"); + Ok(()) + } + crate::CronCommands::Resume { id } => { + resume_job(config, &id)?; + println!("▶️ Resumed cron job {id}"); + Ok(()) + } } } -pub fn add_job(config: &Config, expression: &str, command: &str) -> Result { - let now = Utc::now(); - let next_run = next_run_for(expression, now)?; - let id = Uuid::new_v4().to_string(); - - with_connection(config, |conn| { - conn.execute( - "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) - VALUES (?1, ?2, ?3, ?4, ?5)", - params![ - id, - expression, - command, - now.to_rfc3339(), - next_run.to_rfc3339() - ], - ) - .context("Failed to insert cron job")?; - Ok(()) - })?; - - Ok(CronJob { - id, - expression: expression.to_string(), - command: command.to_string(), - next_run, - last_run: None, - last_status: None, - }) +pub fn add_once(config: &Config, delay: &str, command: &str) -> Result { + let duration = parse_delay(delay)?; + let at = chrono::Utc::now() + duration; + add_once_at(config, at, command) } -pub fn list_jobs(config: &Config) -> Result> { - with_connection(config, |conn| { - let mut stmt = conn.prepare( - "SELECT id, expression, command, next_run, last_run, last_status - FROM cron_jobs ORDER BY next_run ASC", - )?; - - let rows = stmt.query_map([], |row| { - let next_run_raw: String = row.get(3)?; - let last_run_raw: Option = row.get(4)?; - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - next_run_raw, - last_run_raw, - row.get::<_, Option>(5)?, - )) - })?; - - let mut jobs = Vec::new(); - for row in rows { - let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; - jobs.push(CronJob { - id, - expression, - command, - next_run: parse_rfc3339(&next_run_raw)?, - last_run: match last_run_raw { - Some(raw) => Some(parse_rfc3339(&raw)?), - None => None, - }, - last_status, - }); - } - Ok(jobs) - }) -} - -pub fn remove_job(config: &Config, id: &str) -> Result<()> { - let changed = with_connection(config, |conn| { - conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id]) - .context("Failed to delete cron job") - })?; - - if changed == 0 { - anyhow::bail!("Cron job '{id}' not found"); - } - - println!("✅ Removed cron job {id}"); - Ok(()) -} - -pub fn due_jobs(config: &Config, now: DateTime) -> Result> { - with_connection(config, |conn| { - let mut stmt = conn.prepare( - "SELECT id, expression, command, next_run, last_run, last_status - FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC", - )?; - - let rows = stmt.query_map(params![now.to_rfc3339()], |row| { - let next_run_raw: String = row.get(3)?; - let last_run_raw: Option = row.get(4)?; - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - next_run_raw, - last_run_raw, - row.get::<_, Option>(5)?, - )) - })?; - - let mut jobs = Vec::new(); - for row in rows { - let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; - jobs.push(CronJob { - id, - expression, - command, - next_run: parse_rfc3339(&next_run_raw)?, - last_run: match last_run_raw { - Some(raw) => Some(parse_rfc3339(&raw)?), - None => None, - }, - last_status, - }); - } - Ok(jobs) - }) -} - -pub fn reschedule_after_run( +pub fn add_once_at( config: &Config, - job: &CronJob, - success: bool, - output: &str, -) -> Result<()> { - let now = Utc::now(); - let next_run = next_run_for(&job.expression, now)?; - let status = if success { "ok" } else { "error" }; - - with_connection(config, |conn| { - conn.execute( - "UPDATE cron_jobs - SET next_run = ?1, last_run = ?2, last_status = ?3, last_output = ?4 - WHERE id = ?5", - params![ - next_run.to_rfc3339(), - now.to_rfc3339(), - status, - output, - job.id - ], - ) - .context("Failed to update cron job run state")?; - Ok(()) - }) + at: chrono::DateTime, + command: &str, +) -> Result { + let schedule = Schedule::At { at }; + add_shell_job(config, None, schedule, command) } -fn next_run_for(expression: &str, from: DateTime) -> Result> { - let normalized = normalize_expression(expression)?; - let schedule = Schedule::from_str(&normalized) - .with_context(|| format!("Invalid cron expression: {expression}"))?; - schedule - .after(&from) - .next() - .ok_or_else(|| anyhow::anyhow!("No future occurrence for expression: {expression}")) -} - -fn normalize_expression(expression: &str) -> Result { - let expression = expression.trim(); - let field_count = expression.split_whitespace().count(); - - match field_count { - // standard crontab syntax: minute hour day month weekday - 5 => Ok(format!("0 {expression}")), - // crate-native syntax includes seconds (+ optional year) - 6 | 7 => Ok(expression.to_string()), - _ => anyhow::bail!( - "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})" - ), - } -} - -fn parse_rfc3339(raw: &str) -> Result> { - let parsed = DateTime::parse_from_rfc3339(raw) - .with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?; - Ok(parsed.with_timezone(&Utc)) -} - -fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) -> Result { - let db_path = config.workspace_dir.join("cron").join("jobs.db"); - if let Some(parent) = db_path.parent() { - std::fs::create_dir_all(parent) - .with_context(|| format!("Failed to create cron directory: {}", parent.display()))?; - } - - let conn = Connection::open(&db_path) - .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?; - - conn.execute_batch( - "CREATE TABLE IF NOT EXISTS cron_jobs ( - id TEXT PRIMARY KEY, - expression TEXT NOT NULL, - command TEXT NOT NULL, - created_at TEXT NOT NULL, - next_run TEXT NOT NULL, - last_run TEXT, - last_status TEXT, - last_output TEXT - ); - CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);", +pub fn pause_job(config: &Config, id: &str) -> Result { + update_job( + config, + id, + CronJobPatch { + enabled: Some(false), + ..CronJobPatch::default() + }, ) - .context("Failed to initialize cron schema")?; - - f(&conn) } -#[cfg(test)] -mod tests { - use super::*; - use crate::config::Config; - use chrono::Duration as ChronoDuration; - use tempfile::TempDir; - - fn test_config(tmp: &TempDir) -> Config { - let config = Config { - workspace_dir: tmp.path().join("workspace"), - config_path: tmp.path().join("config.toml"), - ..Config::default() - }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); - config - } - - #[test] - fn add_job_accepts_five_field_expression() { - let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); - - let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); - - assert_eq!(job.expression, "*/5 * * * *"); - assert_eq!(job.command, "echo ok"); - } - - #[test] - fn add_job_rejects_invalid_field_count() { - let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); - - let err = add_job(&config, "* * * *", "echo bad").unwrap_err(); - assert!(err.to_string().contains("expected 5, 6, or 7 fields")); - } - - #[test] - fn add_list_remove_roundtrip() { - let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); - - let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap(); - let listed = list_jobs(&config).unwrap(); - assert_eq!(listed.len(), 1); - assert_eq!(listed[0].id, job.id); - - remove_job(&config, &job.id).unwrap(); - assert!(list_jobs(&config).unwrap().is_empty()); - } - - #[test] - fn due_jobs_filters_by_timestamp() { - let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); - - let _job = add_job(&config, "* * * * *", "echo due").unwrap(); - - let due_now = due_jobs(&config, Utc::now()).unwrap(); - assert!(due_now.is_empty(), "new job should not be due immediately"); - - let far_future = Utc::now() + ChronoDuration::days(365); - let due_future = due_jobs(&config, far_future).unwrap(); - assert_eq!(due_future.len(), 1, "job should be due in far future"); - } - - #[test] - fn reschedule_after_run_persists_last_status_and_last_run() { - let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); - - let job = add_job(&config, "*/15 * * * *", "echo run").unwrap(); - reschedule_after_run(&config, &job, false, "failed output").unwrap(); - - let listed = list_jobs(&config).unwrap(); - let stored = listed.iter().find(|j| j.id == job.id).unwrap(); - assert_eq!(stored.last_status.as_deref(), Some("error")); - assert!(stored.last_run.is_some()); - } +pub fn resume_job(config: &Config, id: &str) -> Result { + update_job( + config, + id, + CronJobPatch { + enabled: Some(true), + ..CronJobPatch::default() + }, + ) +} + +fn parse_delay(input: &str) -> Result { + let input = input.trim(); + if input.is_empty() { + anyhow::bail!("delay must not be empty"); + } + let split = input + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(input.len()); + let (num, unit) = input.split_at(split); + let amount: i64 = num.parse()?; + let unit = if unit.is_empty() { "m" } else { unit }; + let duration = match unit { + "s" => chrono::Duration::seconds(amount), + "m" => chrono::Duration::minutes(amount), + "h" => chrono::Duration::hours(amount), + "d" => chrono::Duration::days(amount), + _ => anyhow::bail!("unsupported delay unit '{unit}', use s/m/h/d"), + }; + Ok(duration) } diff --git a/src/cron/schedule.rs b/src/cron/schedule.rs new file mode 100644 index 0000000..d7206b7 --- /dev/null +++ b/src/cron/schedule.rs @@ -0,0 +1,114 @@ +use crate::cron::Schedule; +use anyhow::{Context, Result}; +use chrono::{DateTime, Duration as ChronoDuration, Utc}; +use cron::Schedule as CronExprSchedule; +use std::str::FromStr; + +pub fn next_run_for_schedule(schedule: &Schedule, from: DateTime) -> Result> { + match schedule { + Schedule::Cron { expr, tz } => { + let normalized = normalize_expression(expr)?; + let cron = CronExprSchedule::from_str(&normalized) + .with_context(|| format!("Invalid cron expression: {expr}"))?; + + if let Some(tz_name) = tz { + let timezone = chrono_tz::Tz::from_str(tz_name) + .with_context(|| format!("Invalid IANA timezone: {tz_name}"))?; + let localized_from = from.with_timezone(&timezone); + let next_local = cron.after(&localized_from).next().ok_or_else(|| { + anyhow::anyhow!("No future occurrence for expression: {expr}") + })?; + Ok(next_local.with_timezone(&Utc)) + } else { + cron.after(&from) + .next() + .ok_or_else(|| anyhow::anyhow!("No future occurrence for expression: {expr}")) + } + } + Schedule::At { at } => Ok(*at), + Schedule::Every { every_ms } => { + if *every_ms == 0 { + anyhow::bail!("Invalid schedule: every_ms must be > 0"); + } + let ms = i64::try_from(*every_ms).context("every_ms is too large")?; + let delta = ChronoDuration::milliseconds(ms); + from.checked_add_signed(delta) + .ok_or_else(|| anyhow::anyhow!("every_ms overflowed DateTime")) + } + } +} + +pub fn validate_schedule(schedule: &Schedule, now: DateTime) -> Result<()> { + match schedule { + Schedule::Cron { expr, .. } => { + let _ = normalize_expression(expr)?; + let _ = next_run_for_schedule(schedule, now)?; + Ok(()) + } + Schedule::At { at } => { + if *at <= now { + anyhow::bail!("Invalid schedule: 'at' must be in the future"); + } + Ok(()) + } + Schedule::Every { every_ms } => { + if *every_ms == 0 { + anyhow::bail!("Invalid schedule: every_ms must be > 0"); + } + Ok(()) + } + } +} + +pub fn schedule_cron_expression(schedule: &Schedule) -> Option { + match schedule { + Schedule::Cron { expr, .. } => Some(expr.clone()), + _ => None, + } +} + +pub fn normalize_expression(expression: &str) -> Result { + let expression = expression.trim(); + let field_count = expression.split_whitespace().count(); + + match field_count { + // standard crontab syntax: minute hour day month weekday + 5 => Ok(format!("0 {expression}")), + // crate-native syntax includes seconds (+ optional year) + 6 | 7 => Ok(expression.to_string()), + _ => anyhow::bail!( + "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})" + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::TimeZone; + + #[test] + fn next_run_for_schedule_supports_every_and_at() { + let now = Utc::now(); + let every = Schedule::Every { every_ms: 60_000 }; + let next = next_run_for_schedule(&every, now).unwrap(); + assert!(next > now); + + let at = now + ChronoDuration::minutes(10); + let at_schedule = Schedule::At { at }; + let next_at = next_run_for_schedule(&at_schedule, now).unwrap(); + assert_eq!(next_at, at); + } + + #[test] + fn next_run_for_schedule_supports_timezone() { + let from = Utc.with_ymd_and_hms(2026, 2, 16, 0, 0, 0).unwrap(); + let schedule = Schedule::Cron { + expr: "0 9 * * *".into(), + tz: Some("America/Los_Angeles".into()), + }; + + let next = next_run_for_schedule(&schedule, from).unwrap(); + assert_eq!(next, Utc.with_ymd_and_hms(2026, 2, 16, 17, 0, 0).unwrap()); + } +} diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index 0453999..e50ef78 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -1,8 +1,14 @@ +use crate::channels::{ + Channel, DiscordChannel, MattermostChannel, SendMessage, SlackChannel, TelegramChannel, +}; use crate::config::Config; -use crate::cron::{due_jobs, reschedule_after_run, CronJob}; +use crate::cron::{ + due_jobs, next_run_for_schedule, record_last_run, record_run, remove_job, reschedule_after_run, + update_job, CronJob, CronJobPatch, DeliveryConfig, JobType, Schedule, SessionTarget, +}; use crate::security::SecurityPolicy; use anyhow::Result; -use chrono::Utc; +use chrono::{DateTime, Utc}; use tokio::process::Command; use tokio::time::{self, Duration}; @@ -29,20 +35,26 @@ pub async fn run(config: Config) -> Result<()> { for job in jobs { crate::health::mark_component_ok("scheduler"); + warn_if_high_frequency_agent_job(&job); + + let started_at = Utc::now(); let (success, output) = execute_job_with_retry(&config, &security, &job).await; + let finished_at = Utc::now(); + let success = + persist_job_result(&config, &job, success, &output, started_at, finished_at).await; if !success { crate::health::mark_component_error("scheduler", format!("job {} failed", job.id)); } - - if let Err(e) = reschedule_after_run(&config, &job, success, &output) { - crate::health::mark_component_error("scheduler", e.to_string()); - tracing::warn!("Failed to persist scheduler run result: {e}"); - } } } } +pub async fn execute_job_now(config: &Config, job: &CronJob) -> (bool, String) { + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + execute_job_with_retry(config, &security, job).await +} + async fn execute_job_with_retry( config: &Config, security: &SecurityPolicy, @@ -53,7 +65,10 @@ async fn execute_job_with_retry( let mut backoff_ms = config.reliability.provider_backoff_ms.max(200); for attempt in 0..=retries { - let (success, output) = run_job_command(config, security, job).await; + let (success, output) = match job.job_type { + JobType::Shell => run_job_command(config, security, job).await, + JobType::Agent => run_agent_job(config, job).await, + }; last_output = output; if success { @@ -75,6 +90,200 @@ async fn execute_job_with_retry( (false, last_output) } +async fn run_agent_job(config: &Config, job: &CronJob) -> (bool, String) { + let name = job.name.clone().unwrap_or_else(|| "cron-job".to_string()); + let prompt = job.prompt.clone().unwrap_or_default(); + let prefixed_prompt = format!("[cron:{} {name}] {prompt}", job.id); + let model_override = job.model.clone(); + + let run_result = match job.session_target { + SessionTarget::Main | SessionTarget::Isolated => { + crate::agent::run( + config.clone(), + Some(prefixed_prompt), + None, + model_override, + config.default_temperature, + vec![], + ) + .await + } + }; + + match run_result { + Ok(response) => ( + true, + if response.trim().is_empty() { + "agent job executed".to_string() + } else { + response + }, + ), + Err(e) => (false, format!("agent job failed: {e}")), + } +} + +async fn persist_job_result( + config: &Config, + job: &CronJob, + mut success: bool, + output: &str, + started_at: DateTime, + finished_at: DateTime, +) -> bool { + let duration_ms = (finished_at - started_at).num_milliseconds(); + + if let Err(e) = deliver_if_configured(config, job, output).await { + if job.delivery.best_effort { + tracing::warn!("Cron delivery failed (best_effort): {e}"); + } else { + success = false; + tracing::warn!("Cron delivery failed: {e}"); + } + } + + let _ = record_run( + config, + &job.id, + started_at, + finished_at, + if success { "ok" } else { "error" }, + Some(output), + duration_ms, + ); + + if is_one_shot_auto_delete(job) { + if success { + if let Err(e) = remove_job(config, &job.id) { + tracing::warn!("Failed to remove one-shot cron job after success: {e}"); + } + } else { + let _ = record_last_run(config, &job.id, finished_at, false, output); + if let Err(e) = update_job( + config, + &job.id, + CronJobPatch { + enabled: Some(false), + ..CronJobPatch::default() + }, + ) { + tracing::warn!("Failed to disable failed one-shot cron job: {e}"); + } + } + return success; + } + + if let Err(e) = reschedule_after_run(config, job, success, output) { + tracing::warn!("Failed to persist scheduler run result: {e}"); + } + + success +} + +fn is_one_shot_auto_delete(job: &CronJob) -> bool { + job.delete_after_run && matches!(job.schedule, Schedule::At { .. }) +} + +fn warn_if_high_frequency_agent_job(job: &CronJob) { + if !matches!(job.job_type, JobType::Agent) { + return; + } + let too_frequent = match &job.schedule { + Schedule::Every { every_ms } => *every_ms < 5 * 60 * 1000, + Schedule::Cron { .. } => { + let now = Utc::now(); + match ( + next_run_for_schedule(&job.schedule, now), + next_run_for_schedule(&job.schedule, now + chrono::Duration::seconds(1)), + ) { + (Ok(a), Ok(b)) => (b - a).num_minutes() < 5, + _ => false, + } + } + Schedule::At { .. } => false, + }; + + if too_frequent { + tracing::warn!( + "Cron agent job '{}' is scheduled more frequently than every 5 minutes", + job.id + ); + } +} + +async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) -> Result<()> { + let delivery: &DeliveryConfig = &job.delivery; + if !delivery.mode.eq_ignore_ascii_case("announce") { + return Ok(()); + } + + let channel = delivery + .channel + .as_deref() + .ok_or_else(|| anyhow::anyhow!("delivery.channel is required for announce mode"))?; + let target = delivery + .to + .as_deref() + .ok_or_else(|| anyhow::anyhow!("delivery.to is required for announce mode"))?; + + match channel.to_ascii_lowercase().as_str() { + "telegram" => { + let tg = config + .channels_config + .telegram + .as_ref() + .ok_or_else(|| anyhow::anyhow!("telegram channel not configured"))?; + let channel = TelegramChannel::new(tg.bot_token.clone(), tg.allowed_users.clone()); + channel.send(&SendMessage::new(output, target)).await?; + } + "discord" => { + let dc = config + .channels_config + .discord + .as_ref() + .ok_or_else(|| anyhow::anyhow!("discord channel not configured"))?; + let channel = DiscordChannel::new( + dc.bot_token.clone(), + dc.guild_id.clone(), + dc.allowed_users.clone(), + dc.listen_to_bots, + dc.mention_only, + ); + channel.send(&SendMessage::new(output, target)).await?; + } + "slack" => { + let sl = config + .channels_config + .slack + .as_ref() + .ok_or_else(|| anyhow::anyhow!("slack channel not configured"))?; + let channel = SlackChannel::new( + sl.bot_token.clone(), + sl.channel_id.clone(), + sl.allowed_users.clone(), + ); + channel.send(&SendMessage::new(output, target)).await?; + } + "mattermost" => { + let mm = config + .channels_config + .mattermost + .as_ref() + .ok_or_else(|| anyhow::anyhow!("mattermost channel not configured"))?; + let channel = MattermostChannel::new( + mm.url.clone(), + mm.bot_token.clone(), + mm.channel_id.clone(), + mm.allowed_users.clone(), + ); + channel.send(&SendMessage::new(output, target)).await?; + } + other => anyhow::bail!("unsupported delivery channel: {other}"), + } + + Ok(()) +} + fn is_env_assignment(word: &str) -> bool { word.contains('=') && word @@ -138,6 +347,20 @@ async fn run_job_command( security: &SecurityPolicy, job: &CronJob, ) -> (bool, String) { + if !security.can_act() { + return ( + false, + "blocked by security policy: autonomy is read-only".to_string(), + ); + } + + if security.is_rate_limited() { + return ( + false, + "blocked by security policy: rate limit exceeded".to_string(), + ); + } + if !security.is_command_allowed(&job.command) { return ( false, @@ -155,6 +378,13 @@ async fn run_job_command( ); } + if !security.record_action() { + return ( + false, + "blocked by security policy: action budget exhausted".to_string(), + ); + } + let output = Command::new("sh") .arg("-lc") .arg(&job.command) @@ -182,7 +412,9 @@ async fn run_job_command( mod tests { use super::*; use crate::config::Config; + use crate::cron::{self, DeliveryConfig}; use crate::security::SecurityPolicy; + use chrono::{Duration as ChronoDuration, Utc}; use tempfile::TempDir; fn test_config(tmp: &TempDir) -> Config { @@ -199,10 +431,24 @@ mod tests { CronJob { id: "test-job".into(), expression: "* * * * *".into(), + schedule: crate::cron::Schedule::Cron { + expr: "* * * * *".into(), + tz: None, + }, command: command.into(), + prompt: None, + name: None, + job_type: JobType::Shell, + session_target: SessionTarget::Isolated, + model: None, + enabled: true, + delivery: DeliveryConfig::default(), + delete_after_run: false, + created_at: Utc::now(), next_run: Utc::now(), last_run: None, last_status: None, + last_output: None, } } @@ -261,6 +507,34 @@ mod tests { assert!(output.contains("/etc/passwd")); } + #[tokio::test] + async fn run_job_command_blocks_readonly_mode() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.autonomy.level = crate::security::AutonomyLevel::ReadOnly; + let job = test_job("echo should-not-run"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + assert!(!success); + assert!(output.contains("blocked by security policy")); + assert!(output.contains("read-only")); + } + + #[tokio::test] + async fn run_job_command_blocks_rate_limited() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.autonomy.max_actions_per_hour = 0; + let job = test_job("echo should-not-run"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + assert!(!success); + assert!(output.contains("blocked by security policy")); + assert!(output.contains("rate limit exceeded")); + } + #[tokio::test] async fn execute_job_with_retry_recovers_after_first_failure() { let tmp = TempDir::new().unwrap(); @@ -296,4 +570,103 @@ mod tests { assert!(!success); assert!(output.contains("always_missing_for_retry_test")); } + + #[tokio::test] + async fn run_agent_job_returns_error_without_provider_key() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let mut job = test_job(""); + job.job_type = JobType::Agent; + job.prompt = Some("Say hello".into()); + + let (success, output) = run_agent_job(&config, &job).await; + assert!(!success); + assert!(output.contains("agent job failed:")); + } + + #[tokio::test] + async fn persist_job_result_records_run_and_reschedules_shell_job() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = cron::add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + let started = Utc::now(); + let finished = started + ChronoDuration::milliseconds(10); + + let success = persist_job_result(&config, &job, true, "ok", started, finished).await; + assert!(success); + + let runs = cron::list_runs(&config, &job.id, 10).unwrap(); + assert_eq!(runs.len(), 1); + let updated = cron::get_job(&config, &job.id).unwrap(); + assert_eq!(updated.last_status.as_deref(), Some("ok")); + } + + #[tokio::test] + async fn persist_job_result_success_deletes_one_shot() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let at = Utc::now() + ChronoDuration::minutes(10); + let job = cron::add_agent_job( + &config, + Some("one-shot".into()), + crate::cron::Schedule::At { at }, + "Hello", + SessionTarget::Isolated, + None, + None, + true, + ) + .unwrap(); + let started = Utc::now(); + let finished = started + ChronoDuration::milliseconds(10); + + let success = persist_job_result(&config, &job, true, "ok", started, finished).await; + assert!(success); + let lookup = cron::get_job(&config, &job.id); + assert!(lookup.is_err()); + } + + #[tokio::test] + async fn persist_job_result_failure_disables_one_shot() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let at = Utc::now() + ChronoDuration::minutes(10); + let job = cron::add_agent_job( + &config, + Some("one-shot".into()), + crate::cron::Schedule::At { at }, + "Hello", + SessionTarget::Isolated, + None, + None, + true, + ) + .unwrap(); + let started = Utc::now(); + let finished = started + ChronoDuration::milliseconds(10); + + let success = persist_job_result(&config, &job, false, "boom", started, finished).await; + assert!(!success); + let updated = cron::get_job(&config, &job.id).unwrap(); + assert!(!updated.enabled); + assert_eq!(updated.last_status.as_deref(), Some("error")); + } + + #[tokio::test] + async fn deliver_if_configured_handles_none_and_invalid_channel() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let mut job = test_job("echo ok"); + + assert!(deliver_if_configured(&config, &job, "x").await.is_ok()); + + job.delivery = DeliveryConfig { + mode: "announce".into(), + channel: Some("invalid".into()), + to: Some("target".into()), + best_effort: true, + }; + let err = deliver_if_configured(&config, &job, "x").await.unwrap_err(); + assert!(err.to_string().contains("unsupported delivery channel")); + } } diff --git a/src/cron/store.rs b/src/cron/store.rs new file mode 100644 index 0000000..013ed55 --- /dev/null +++ b/src/cron/store.rs @@ -0,0 +1,668 @@ +use crate::config::Config; +use crate::cron::{ + next_run_for_schedule, schedule_cron_expression, validate_schedule, CronJob, CronJobPatch, + CronRun, DeliveryConfig, JobType, Schedule, SessionTarget, +}; +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use rusqlite::{params, Connection}; +use uuid::Uuid; + +pub fn add_job(config: &Config, expression: &str, command: &str) -> Result { + let schedule = Schedule::Cron { + expr: expression.to_string(), + tz: None, + }; + add_shell_job(config, None, schedule, command) +} + +pub fn add_shell_job( + config: &Config, + name: Option, + schedule: Schedule, + command: &str, +) -> Result { + let now = Utc::now(); + validate_schedule(&schedule, now)?; + let next_run = next_run_for_schedule(&schedule, now)?; + let id = Uuid::new_v4().to_string(); + let expression = schedule_cron_expression(&schedule).unwrap_or_default(); + let schedule_json = serde_json::to_string(&schedule)?; + + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_jobs ( + id, expression, command, schedule, job_type, prompt, name, session_target, model, + enabled, delivery, delete_after_run, created_at, next_run + ) VALUES (?1, ?2, ?3, ?4, 'shell', NULL, ?5, 'isolated', NULL, 1, ?6, 0, ?7, ?8)", + params![ + id, + expression, + command, + schedule_json, + name, + serde_json::to_string(&DeliveryConfig::default())?, + now.to_rfc3339(), + next_run.to_rfc3339(), + ], + ) + .context("Failed to insert cron shell job")?; + Ok(()) + })?; + + get_job(config, &id) +} + +#[allow(clippy::too_many_arguments)] +pub fn add_agent_job( + config: &Config, + name: Option, + schedule: Schedule, + prompt: &str, + session_target: SessionTarget, + model: Option, + delivery: Option, + delete_after_run: bool, +) -> Result { + let now = Utc::now(); + validate_schedule(&schedule, now)?; + let next_run = next_run_for_schedule(&schedule, now)?; + let id = Uuid::new_v4().to_string(); + let expression = schedule_cron_expression(&schedule).unwrap_or_default(); + let schedule_json = serde_json::to_string(&schedule)?; + let delivery = delivery.unwrap_or_default(); + + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_jobs ( + id, expression, command, schedule, job_type, prompt, name, session_target, model, + enabled, delivery, delete_after_run, created_at, next_run + ) VALUES (?1, ?2, '', ?3, 'agent', ?4, ?5, ?6, ?7, 1, ?8, ?9, ?10, ?11)", + params![ + id, + expression, + schedule_json, + prompt, + name, + session_target.as_str(), + model, + serde_json::to_string(&delivery)?, + if delete_after_run { 1 } else { 0 }, + now.to_rfc3339(), + next_run.to_rfc3339(), + ], + ) + .context("Failed to insert cron agent job")?; + Ok(()) + })?; + + get_job(config, &id) +} + +pub fn list_jobs(config: &Config) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model, + enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output + FROM cron_jobs ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map([], map_cron_job_row)?; + + let mut jobs = Vec::new(); + for row in rows { + jobs.push(row?); + } + Ok(jobs) + }) +} + +pub fn get_job(config: &Config, job_id: &str) -> Result { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model, + enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output + FROM cron_jobs WHERE id = ?1", + )?; + + let mut rows = stmt.query(params![job_id])?; + if let Some(row) = rows.next()? { + map_cron_job_row(row).map_err(Into::into) + } else { + anyhow::bail!("Cron job '{job_id}' not found") + } + }) +} + +pub fn remove_job(config: &Config, id: &str) -> Result<()> { + let changed = with_connection(config, |conn| { + conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id]) + .context("Failed to delete cron job") + })?; + + if changed == 0 { + anyhow::bail!("Cron job '{id}' not found"); + } + + println!("✅ Removed cron job {id}"); + Ok(()) +} + +pub fn due_jobs(config: &Config, now: DateTime) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model, + enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output + FROM cron_jobs WHERE enabled = 1 AND next_run <= ?1 ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map(params![now.to_rfc3339()], map_cron_job_row)?; + + let mut jobs = Vec::new(); + for row in rows { + jobs.push(row?); + } + Ok(jobs) + }) +} + +pub fn update_job(config: &Config, job_id: &str, patch: CronJobPatch) -> Result { + let mut job = get_job(config, job_id)?; + let mut schedule_changed = false; + + if let Some(schedule) = patch.schedule { + validate_schedule(&schedule, Utc::now())?; + job.schedule = schedule; + job.expression = schedule_cron_expression(&job.schedule).unwrap_or_default(); + schedule_changed = true; + } + if let Some(command) = patch.command { + job.command = command; + } + if let Some(prompt) = patch.prompt { + job.prompt = Some(prompt); + } + if let Some(name) = patch.name { + job.name = Some(name); + } + if let Some(enabled) = patch.enabled { + job.enabled = enabled; + } + if let Some(delivery) = patch.delivery { + job.delivery = delivery; + } + if let Some(model) = patch.model { + job.model = Some(model); + } + if let Some(target) = patch.session_target { + job.session_target = target; + } + if let Some(delete_after_run) = patch.delete_after_run { + job.delete_after_run = delete_after_run; + } + + if schedule_changed { + job.next_run = next_run_for_schedule(&job.schedule, Utc::now())?; + } + + with_connection(config, |conn| { + conn.execute( + "UPDATE cron_jobs + SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4, prompt = ?5, name = ?6, + session_target = ?7, model = ?8, enabled = ?9, delivery = ?10, delete_after_run = ?11, + next_run = ?12 + WHERE id = ?13", + params![ + job.expression, + job.command, + serde_json::to_string(&job.schedule)?, + job.job_type.as_str(), + job.prompt, + job.name, + job.session_target.as_str(), + job.model, + if job.enabled { 1 } else { 0 }, + serde_json::to_string(&job.delivery)?, + if job.delete_after_run { 1 } else { 0 }, + job.next_run.to_rfc3339(), + job.id, + ], + ) + .context("Failed to update cron job")?; + Ok(()) + })?; + + get_job(config, job_id) +} + +pub fn record_last_run( + config: &Config, + job_id: &str, + finished_at: DateTime, + success: bool, + output: &str, +) -> Result<()> { + let status = if success { "ok" } else { "error" }; + with_connection(config, |conn| { + conn.execute( + "UPDATE cron_jobs + SET last_run = ?1, last_status = ?2, last_output = ?3 + WHERE id = ?4", + params![finished_at.to_rfc3339(), status, output, job_id], + ) + .context("Failed to update cron last run fields")?; + Ok(()) + }) +} + +pub fn reschedule_after_run( + config: &Config, + job: &CronJob, + success: bool, + output: &str, +) -> Result<()> { + let now = Utc::now(); + let next_run = next_run_for_schedule(&job.schedule, now)?; + let status = if success { "ok" } else { "error" }; + + with_connection(config, |conn| { + conn.execute( + "UPDATE cron_jobs + SET next_run = ?1, last_run = ?2, last_status = ?3, last_output = ?4 + WHERE id = ?5", + params![ + next_run.to_rfc3339(), + now.to_rfc3339(), + status, + output, + job.id + ], + ) + .context("Failed to update cron job run state")?; + Ok(()) + }) +} + +pub fn record_run( + config: &Config, + job_id: &str, + started_at: DateTime, + finished_at: DateTime, + status: &str, + output: Option<&str>, + duration_ms: i64, +) -> Result<()> { + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_runs (job_id, started_at, finished_at, status, output, duration_ms) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + job_id, + started_at.to_rfc3339(), + finished_at.to_rfc3339(), + status, + output, + duration_ms, + ], + ) + .context("Failed to insert cron run")?; + + let keep = i64::from(config.cron.max_run_history.max(1)); + conn.execute( + "DELETE FROM cron_runs + WHERE job_id = ?1 + AND id NOT IN ( + SELECT id FROM cron_runs + WHERE job_id = ?1 + ORDER BY started_at DESC, id DESC + LIMIT ?2 + )", + params![job_id, keep], + ) + .context("Failed to prune cron run history")?; + Ok(()) + }) +} + +pub fn list_runs(config: &Config, job_id: &str, limit: usize) -> Result> { + with_connection(config, |conn| { + let lim = i64::try_from(limit.max(1)).context("Run history limit overflow")?; + let mut stmt = conn.prepare( + "SELECT id, job_id, started_at, finished_at, status, output, duration_ms + FROM cron_runs + WHERE job_id = ?1 + ORDER BY started_at DESC, id DESC + LIMIT ?2", + )?; + + let rows = stmt.query_map(params![job_id, lim], |row| { + Ok(CronRun { + id: row.get(0)?, + job_id: row.get(1)?, + started_at: parse_rfc3339(&row.get::<_, String>(2)?) + .map_err(sql_conversion_error)?, + finished_at: parse_rfc3339(&row.get::<_, String>(3)?) + .map_err(sql_conversion_error)?, + status: row.get(4)?, + output: row.get(5)?, + duration_ms: row.get(6)?, + }) + })?; + + let mut runs = Vec::new(); + for row in rows { + runs.push(row?); + } + Ok(runs) + }) +} + +fn parse_rfc3339(raw: &str) -> Result> { + let parsed = DateTime::parse_from_rfc3339(raw) + .with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?; + Ok(parsed.with_timezone(&Utc)) +} + +fn sql_conversion_error(err: anyhow::Error) -> rusqlite::Error { + rusqlite::Error::ToSqlConversionFailure(err.into()) +} + +fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result { + let expression: String = row.get(1)?; + let schedule_raw: Option = row.get(3)?; + let schedule = + decode_schedule(schedule_raw.as_deref(), &expression).map_err(sql_conversion_error)?; + + let delivery_raw: Option = row.get(10)?; + let delivery = decode_delivery(delivery_raw.as_deref()).map_err(sql_conversion_error)?; + + let next_run_raw: String = row.get(13)?; + let last_run_raw: Option = row.get(14)?; + let created_at_raw: String = row.get(12)?; + + Ok(CronJob { + id: row.get(0)?, + expression, + schedule, + command: row.get(2)?, + job_type: JobType::parse(&row.get::<_, String>(4)?), + prompt: row.get(5)?, + name: row.get(6)?, + session_target: SessionTarget::parse(&row.get::<_, String>(7)?), + model: row.get(8)?, + enabled: row.get::<_, i64>(9)? != 0, + delivery, + delete_after_run: row.get::<_, i64>(11)? != 0, + created_at: parse_rfc3339(&created_at_raw).map_err(sql_conversion_error)?, + next_run: parse_rfc3339(&next_run_raw).map_err(sql_conversion_error)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw).map_err(sql_conversion_error)?), + None => None, + }, + last_status: row.get(15)?, + last_output: row.get(16)?, + }) +} + +fn decode_schedule(schedule_raw: Option<&str>, expression: &str) -> Result { + if let Some(raw) = schedule_raw { + let trimmed = raw.trim(); + if !trimmed.is_empty() { + return serde_json::from_str(trimmed) + .with_context(|| format!("Failed to parse cron schedule JSON: {trimmed}")); + } + } + + if expression.trim().is_empty() { + anyhow::bail!("Missing schedule and legacy expression for cron job") + } + + Ok(Schedule::Cron { + expr: expression.to_string(), + tz: None, + }) +} + +fn decode_delivery(delivery_raw: Option<&str>) -> Result { + if let Some(raw) = delivery_raw { + let trimmed = raw.trim(); + if !trimmed.is_empty() { + return serde_json::from_str(trimmed) + .with_context(|| format!("Failed to parse cron delivery JSON: {trimmed}")); + } + } + Ok(DeliveryConfig::default()) +} + +fn add_column_if_missing(conn: &Connection, name: &str, sql_type: &str) -> Result<()> { + let mut stmt = conn.prepare("PRAGMA table_info(cron_jobs)")?; + let mut rows = stmt.query([])?; + while let Some(row) = rows.next()? { + let col_name: String = row.get(1)?; + if col_name == name { + return Ok(()); + } + } + + conn.execute( + &format!("ALTER TABLE cron_jobs ADD COLUMN {name} {sql_type}"), + [], + ) + .with_context(|| format!("Failed to add cron_jobs.{name}"))?; + Ok(()) +} + +fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) -> Result { + let db_path = config.workspace_dir.join("cron").join("jobs.db"); + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("Failed to create cron directory: {}", parent.display()))?; + } + + let conn = Connection::open(&db_path) + .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?; + + conn.execute_batch( + "PRAGMA foreign_keys = ON; + CREATE TABLE IF NOT EXISTS cron_jobs ( + id TEXT PRIMARY KEY, + expression TEXT NOT NULL, + command TEXT NOT NULL, + schedule TEXT, + job_type TEXT NOT NULL DEFAULT 'shell', + prompt TEXT, + name TEXT, + session_target TEXT NOT NULL DEFAULT 'isolated', + model TEXT, + enabled INTEGER NOT NULL DEFAULT 1, + delivery TEXT, + delete_after_run INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + next_run TEXT NOT NULL, + last_run TEXT, + last_status TEXT, + last_output TEXT + ); + CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run); + + CREATE TABLE IF NOT EXISTS cron_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + job_id TEXT NOT NULL, + started_at TEXT NOT NULL, + finished_at TEXT NOT NULL, + status TEXT NOT NULL, + output TEXT, + duration_ms INTEGER, + FOREIGN KEY (job_id) REFERENCES cron_jobs(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_cron_runs_job_id ON cron_runs(job_id); + CREATE INDEX IF NOT EXISTS idx_cron_runs_started_at ON cron_runs(started_at);", + ) + .context("Failed to initialize cron schema")?; + + add_column_if_missing(&conn, "schedule", "TEXT")?; + add_column_if_missing(&conn, "job_type", "TEXT NOT NULL DEFAULT 'shell'")?; + add_column_if_missing(&conn, "prompt", "TEXT")?; + add_column_if_missing(&conn, "name", "TEXT")?; + add_column_if_missing(&conn, "session_target", "TEXT NOT NULL DEFAULT 'isolated'")?; + add_column_if_missing(&conn, "model", "TEXT")?; + add_column_if_missing(&conn, "enabled", "INTEGER NOT NULL DEFAULT 1")?; + add_column_if_missing(&conn, "delivery", "TEXT")?; + add_column_if_missing(&conn, "delete_after_run", "INTEGER NOT NULL DEFAULT 0")?; + + f(&conn) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use chrono::Duration as ChronoDuration; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + #[test] + fn add_job_accepts_five_field_expression() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + assert_eq!(job.expression, "*/5 * * * *"); + assert_eq!(job.command, "echo ok"); + assert!(matches!(job.schedule, Schedule::Cron { .. })); + } + + #[test] + fn add_list_remove_roundtrip() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap(); + let listed = list_jobs(&config).unwrap(); + assert_eq!(listed.len(), 1); + assert_eq!(listed[0].id, job.id); + + remove_job(&config, &job.id).unwrap(); + assert!(list_jobs(&config).unwrap().is_empty()); + } + + #[test] + fn due_jobs_filters_by_timestamp_and_enabled() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "* * * * *", "echo due").unwrap(); + + let due_now = due_jobs(&config, Utc::now()).unwrap(); + assert!(due_now.is_empty(), "new job should not be due immediately"); + + let far_future = Utc::now() + ChronoDuration::days(365); + let due_future = due_jobs(&config, far_future).unwrap(); + assert_eq!(due_future.len(), 1, "job should be due in far future"); + + let _ = update_job( + &config, + &job.id, + CronJobPatch { + enabled: Some(false), + ..CronJobPatch::default() + }, + ) + .unwrap(); + let due_after_disable = due_jobs(&config, far_future).unwrap(); + assert!(due_after_disable.is_empty()); + } + + #[test] + fn reschedule_after_run_persists_last_status_and_last_run() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/15 * * * *", "echo run").unwrap(); + reschedule_after_run(&config, &job, false, "failed output").unwrap(); + + let listed = list_jobs(&config).unwrap(); + let stored = listed.iter().find(|j| j.id == job.id).unwrap(); + assert_eq!(stored.last_status.as_deref(), Some("error")); + assert!(stored.last_run.is_some()); + assert_eq!(stored.last_output.as_deref(), Some("failed output")); + } + + #[test] + fn migration_falls_back_to_legacy_expression() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + with_connection(&config, |conn| { + conn.execute( + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + "legacy-id", + "*/5 * * * *", + "echo legacy", + Utc::now().to_rfc3339(), + (Utc::now() + ChronoDuration::minutes(5)).to_rfc3339(), + ], + )?; + conn.execute( + "UPDATE cron_jobs SET schedule = NULL WHERE id = 'legacy-id'", + [], + )?; + Ok(()) + }) + .unwrap(); + + let job = get_job(&config, "legacy-id").unwrap(); + assert!(matches!(job.schedule, Schedule::Cron { .. })); + } + + #[test] + fn record_and_prune_runs() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.cron.max_run_history = 2; + let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + let base = Utc::now(); + + for idx in 0..3 { + let start = base + ChronoDuration::seconds(idx); + let end = start + ChronoDuration::milliseconds(100); + record_run(&config, &job.id, start, end, "ok", Some("done"), 100).unwrap(); + } + + let runs = list_runs(&config, &job.id, 10).unwrap(); + assert_eq!(runs.len(), 2); + } + + #[test] + fn remove_job_cascades_run_history() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + let start = Utc::now(); + record_run( + &config, + &job.id, + start, + start + ChronoDuration::milliseconds(5), + "ok", + Some("ok"), + 5, + ) + .unwrap(); + + remove_job(&config, &job.id).unwrap(); + let runs = list_runs(&config, &job.id, 10).unwrap(); + assert!(runs.is_empty()); + } +} diff --git a/src/cron/types.rs b/src/cron/types.rs new file mode 100644 index 0000000..f6d3c66 --- /dev/null +++ b/src/cron/types.rs @@ -0,0 +1,140 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] +pub enum JobType { + #[default] + Shell, + Agent, +} + +impl JobType { + pub(crate) fn as_str(&self) -> &'static str { + match self { + Self::Shell => "shell", + Self::Agent => "agent", + } + } + + pub(crate) fn parse(raw: &str) -> Self { + if raw.eq_ignore_ascii_case("agent") { + Self::Agent + } else { + Self::Shell + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] +pub enum SessionTarget { + #[default] + Isolated, + Main, +} + +impl SessionTarget { + pub(crate) fn as_str(&self) -> &'static str { + match self { + Self::Isolated => "isolated", + Self::Main => "main", + } + } + + pub(crate) fn parse(raw: &str) -> Self { + if raw.eq_ignore_ascii_case("main") { + Self::Main + } else { + Self::Isolated + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "kind", rename_all = "lowercase")] +pub enum Schedule { + Cron { + expr: String, + #[serde(default)] + tz: Option, + }, + At { + at: DateTime, + }, + Every { + every_ms: u64, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct DeliveryConfig { + #[serde(default)] + pub mode: String, + #[serde(default)] + pub channel: Option, + #[serde(default)] + pub to: Option, + #[serde(default = "default_true")] + pub best_effort: bool, +} + +impl Default for DeliveryConfig { + fn default() -> Self { + Self { + mode: "none".to_string(), + channel: None, + to: None, + best_effort: true, + } + } +} + +fn default_true() -> bool { + true +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CronJob { + pub id: String, + pub expression: String, + pub schedule: Schedule, + pub command: String, + pub prompt: Option, + pub name: Option, + pub job_type: JobType, + pub session_target: SessionTarget, + pub model: Option, + pub enabled: bool, + pub delivery: DeliveryConfig, + pub delete_after_run: bool, + pub created_at: DateTime, + pub next_run: DateTime, + pub last_run: Option>, + pub last_status: Option, + pub last_output: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CronRun { + pub id: i64, + pub job_id: String, + pub started_at: DateTime, + pub finished_at: DateTime, + pub status: String, + pub output: Option, + pub duration_ms: Option, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct CronJobPatch { + pub schedule: Option, + pub command: Option, + pub prompt: Option, + pub name: Option, + pub enabled: Option, + pub delivery: Option, + pub model: Option, + pub session_target: Option, + pub delete_after_run: Option, +} diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index e2b3e2c..c60cd2d 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -71,7 +71,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { )); } - { + if config.cron.enabled { let scheduler_cfg = config.clone(); handles.push(spawn_component_supervisor( "scheduler", @@ -82,6 +82,9 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { async move { crate::cron::scheduler::run(cfg).await } }, )); + } else { + crate::health::mark_component_ok("scheduler"); + tracing::info!("Cron disabled; scheduler supervisor not started"); } println!("🧠 ZeroClaw daemon started"); @@ -153,6 +156,8 @@ where Ok(()) => { crate::health::mark_component_error(name, "component exited unexpectedly"); tracing::warn!("Daemon component '{name}' exited unexpectedly"); + // Clean exit — reset backoff since the component ran successfully + backoff = initial_backoff_secs.max(1); } Err(e) => { crate::health::mark_component_error(name, e.to_string()); @@ -162,6 +167,7 @@ where crate::health::bump_component_restart(name); tokio::time::sleep(Duration::from_secs(backoff)).await; + // Double backoff AFTER sleeping so first error uses initial_backoff backoff = backoff.saturating_mul(2).min(max_backoff); } }) @@ -190,7 +196,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { for task in tasks { let prompt = format!("[Heartbeat Task] {task}"); let temp = config.default_temperature; - if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await + if let Err(e) = + crate::agent::run(config.clone(), Some(prompt), None, None, temp, vec![]).await { crate::health::mark_component_error("heartbeat", e.to_string()); tracing::warn!("Heartbeat task failed: {e}"); @@ -207,6 +214,12 @@ fn has_supervised_channels(config: &Config) -> bool { || config.channels_config.slack.is_some() || config.channels_config.imessage.is_some() || config.channels_config.matrix.is_some() + || config.channels_config.signal.is_some() + || config.channels_config.whatsapp.is_some() + || config.channels_config.email.is_some() + || config.channels_config.irc.is_some() + || config.channels_config.lark.is_some() + || config.channels_config.dingtalk.is_some() } #[cfg(test)] @@ -286,4 +299,15 @@ mod tests { }); assert!(has_supervised_channels(&config)); } + + #[test] + fn detects_dingtalk_as_supervised_channel() { + let mut config = Config::default(); + config.channels_config.dingtalk = Some(crate::config::schema::DingTalkConfig { + client_id: "client_id".into(), + client_secret: "client_secret".into(), + allowed_users: vec!["*".into()], + }); + assert!(has_supervised_channels(&config)); + } } diff --git a/src/doctor/mod.rs b/src/doctor/mod.rs index e858f7c..6db91fc 100644 --- a/src/doctor/mod.rs +++ b/src/doctor/mod.rs @@ -1,28 +1,419 @@ use crate::config::Config; -use anyhow::{Context, Result}; +use anyhow::Result; use chrono::{DateTime, Utc}; +use std::io::Write; +use std::path::Path; const DAEMON_STALE_SECONDS: i64 = 30; const SCHEDULER_STALE_SECONDS: i64 = 120; const CHANNEL_STALE_SECONDS: i64 = 300; +const COMMAND_VERSION_PREVIEW_CHARS: usize = 60; -pub fn run(config: &Config) -> Result<()> { - let state_file = crate::daemon::state_file_path(config); - if !state_file.exists() { - println!("🩺 ZeroClaw Doctor"); - println!(" ❌ daemon state file not found: {}", state_file.display()); - println!(" 💡 Start daemon with: zeroclaw daemon"); - return Ok(()); +// ── Diagnostic item ────────────────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Severity { + Ok, + Warn, + Error, +} + +struct DiagItem { + severity: Severity, + category: &'static str, + message: String, +} + +impl DiagItem { + fn ok(category: &'static str, msg: impl Into) -> Self { + Self { + severity: Severity::Ok, + category, + message: msg.into(), + } + } + fn warn(category: &'static str, msg: impl Into) -> Self { + Self { + severity: Severity::Warn, + category, + message: msg.into(), + } + } + fn error(category: &'static str, msg: impl Into) -> Self { + Self { + severity: Severity::Error, + category, + message: msg.into(), + } } - let raw = std::fs::read_to_string(&state_file) - .with_context(|| format!("Failed to read {}", state_file.display()))?; - let snapshot: serde_json::Value = serde_json::from_str(&raw) - .with_context(|| format!("Failed to parse {}", state_file.display()))?; + fn icon(&self) -> &'static str { + match self.severity { + Severity::Ok => "✅", + Severity::Warn => "⚠️ ", + Severity::Error => "❌", + } + } +} - println!("🩺 ZeroClaw Doctor"); - println!(" State file: {}", state_file.display()); +// ── Public entry point ─────────────────────────────────────────── +pub fn run(config: &Config) -> Result<()> { + let mut items: Vec = Vec::new(); + + check_config_semantics(config, &mut items); + check_workspace(config, &mut items); + check_daemon_state(config, &mut items); + check_environment(&mut items); + + // Print report + println!("🩺 ZeroClaw Doctor (enhanced)"); + println!(); + + let mut current_cat = ""; + for item in &items { + if item.category != current_cat { + current_cat = item.category; + println!(" [{current_cat}]"); + } + println!(" {} {}", item.icon(), item.message); + } + + let errors = items + .iter() + .filter(|i| i.severity == Severity::Error) + .count(); + let warns = items + .iter() + .filter(|i| i.severity == Severity::Warn) + .count(); + let oks = items.iter().filter(|i| i.severity == Severity::Ok).count(); + + println!(); + println!(" Summary: {oks} ok, {warns} warnings, {errors} errors"); + + if errors > 0 { + println!(" 💡 Fix the errors above, then run `zeroclaw doctor` again."); + } + + Ok(()) +} + +// ── Config semantic validation ─────────────────────────────────── + +fn check_config_semantics(config: &Config, items: &mut Vec) { + let cat = "config"; + + // Config file exists + if config.config_path.exists() { + items.push(DiagItem::ok( + cat, + format!("config file: {}", config.config_path.display()), + )); + } else { + items.push(DiagItem::error( + cat, + format!("config file not found: {}", config.config_path.display()), + )); + } + + // Provider validity + if let Some(ref provider) = config.default_provider { + if let Some(reason) = provider_validation_error(provider) { + items.push(DiagItem::error( + cat, + format!("default provider \"{provider}\" is invalid: {reason}"), + )); + } else { + items.push(DiagItem::ok( + cat, + format!("provider \"{provider}\" is valid"), + )); + } + } else { + items.push(DiagItem::error(cat, "no default_provider configured")); + } + + // API key presence + if config.default_provider.as_deref() != Some("ollama") { + if config.api_key.is_some() { + items.push(DiagItem::ok(cat, "API key configured")); + } else { + items.push(DiagItem::warn( + cat, + "no api_key set (may rely on env vars or provider defaults)", + )); + } + } + + // Model configured + if config.default_model.is_some() { + items.push(DiagItem::ok( + cat, + format!( + "default model: {}", + config.default_model.as_deref().unwrap_or("?") + ), + )); + } else { + items.push(DiagItem::warn(cat, "no default_model configured")); + } + + // Temperature range + if config.default_temperature >= 0.0 && config.default_temperature <= 2.0 { + items.push(DiagItem::ok( + cat, + format!( + "temperature {:.1} (valid range 0.0–2.0)", + config.default_temperature + ), + )); + } else { + items.push(DiagItem::error( + cat, + format!( + "temperature {:.1} is out of range (expected 0.0–2.0)", + config.default_temperature + ), + )); + } + + // Gateway port range + let port = config.gateway.port; + if port > 0 { + items.push(DiagItem::ok(cat, format!("gateway port: {port}"))); + } else { + items.push(DiagItem::error(cat, "gateway port is 0 (invalid)")); + } + + // Reliability: fallback providers + for fb in &config.reliability.fallback_providers { + if let Some(reason) = provider_validation_error(fb) { + items.push(DiagItem::warn( + cat, + format!("fallback provider \"{fb}\" is invalid: {reason}"), + )); + } + } + + // Model routes validation + for route in &config.model_routes { + if route.hint.is_empty() { + items.push(DiagItem::warn(cat, "model route with empty hint")); + } + if let Some(reason) = provider_validation_error(&route.provider) { + items.push(DiagItem::warn( + cat, + format!( + "model route \"{}\" uses invalid provider \"{}\": {}", + route.hint, route.provider, reason + ), + )); + } + if route.model.is_empty() { + items.push(DiagItem::warn( + cat, + format!("model route \"{}\" has empty model", route.hint), + )); + } + } + + // Channel: at least one configured + let cc = &config.channels_config; + let has_channel = cc.telegram.is_some() + || cc.discord.is_some() + || cc.slack.is_some() + || cc.imessage.is_some() + || cc.matrix.is_some() + || cc.whatsapp.is_some() + || cc.email.is_some() + || cc.irc.is_some() + || cc.lark.is_some() + || cc.webhook.is_some(); + + if has_channel { + items.push(DiagItem::ok(cat, "at least one channel configured")); + } else { + items.push(DiagItem::warn( + cat, + "no channels configured — run `zeroclaw onboard` to set one up", + )); + } + + // Delegate agents: provider validity + for (name, agent) in &config.agents { + if let Some(reason) = provider_validation_error(&agent.provider) { + items.push(DiagItem::warn( + cat, + format!( + "agent \"{name}\" uses invalid provider \"{}\": {}", + agent.provider, reason + ), + )); + } + } +} + +fn provider_validation_error(name: &str) -> Option { + match crate::providers::create_provider(name, None) { + Ok(_) => None, + Err(err) => Some( + err.to_string() + .lines() + .next() + .unwrap_or("invalid provider") + .into(), + ), + } +} + +// ── Workspace integrity ────────────────────────────────────────── + +fn check_workspace(config: &Config, items: &mut Vec) { + let cat = "workspace"; + let ws = &config.workspace_dir; + + if ws.exists() { + items.push(DiagItem::ok( + cat, + format!("directory exists: {}", ws.display()), + )); + } else { + items.push(DiagItem::error( + cat, + format!("directory missing: {}", ws.display()), + )); + return; + } + + // Writable check + let probe = workspace_probe_path(ws); + match std::fs::OpenOptions::new() + .write(true) + .create_new(true) + .open(&probe) + { + Ok(mut probe_file) => { + let write_result = probe_file.write_all(b"probe"); + drop(probe_file); + let _ = std::fs::remove_file(&probe); + match write_result { + Ok(()) => items.push(DiagItem::ok(cat, "directory is writable")), + Err(e) => items.push(DiagItem::error( + cat, + format!("directory write probe failed: {e}"), + )), + } + } + Err(e) => { + items.push(DiagItem::error( + cat, + format!("directory is not writable: {e}"), + )); + } + } + + // Disk space (best-effort via `df`) + if let Some(avail_mb) = disk_available_mb(ws) { + if avail_mb >= 100 { + items.push(DiagItem::ok( + cat, + format!("disk space: {avail_mb} MB available"), + )); + } else { + items.push(DiagItem::warn( + cat, + format!("low disk space: only {avail_mb} MB available"), + )); + } + } + + // Key workspace files + check_file_exists(ws, "SOUL.md", false, cat, items); + check_file_exists(ws, "AGENTS.md", false, cat, items); +} + +fn check_file_exists( + base: &Path, + name: &str, + required: bool, + cat: &'static str, + items: &mut Vec, +) { + let path = base.join(name); + if path.is_file() { + items.push(DiagItem::ok(cat, format!("{name} present"))); + } else if required { + items.push(DiagItem::error(cat, format!("{name} missing"))); + } else { + items.push(DiagItem::warn(cat, format!("{name} not found (optional)"))); + } +} + +fn disk_available_mb(path: &Path) -> Option { + let output = std::process::Command::new("df") + .arg("-m") + .arg(path) + .output() + .ok()?; + if !output.status.success() { + return None; + } + let stdout = String::from_utf8_lossy(&output.stdout); + parse_df_available_mb(&stdout) +} + +fn parse_df_available_mb(stdout: &str) -> Option { + let line = stdout.lines().rev().find(|line| !line.trim().is_empty())?; + let avail = line.split_whitespace().nth(3)?; + avail.parse::().ok() +} + +fn workspace_probe_path(workspace_dir: &Path) -> std::path::PathBuf { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_nanos()); + workspace_dir.join(format!( + ".zeroclaw_doctor_probe_{}_{}", + std::process::id(), + nanos + )) +} + +// ── Daemon state (original logic, preserved) ───────────────────── + +fn check_daemon_state(config: &Config, items: &mut Vec) { + let cat = "daemon"; + let state_file = crate::daemon::state_file_path(config); + + if !state_file.exists() { + items.push(DiagItem::error( + cat, + format!( + "state file not found: {} — is the daemon running?", + state_file.display() + ), + )); + return; + } + + let raw = match std::fs::read_to_string(&state_file) { + Ok(r) => r, + Err(e) => { + items.push(DiagItem::error(cat, format!("cannot read state file: {e}"))); + return; + } + }; + + let snapshot: serde_json::Value = match serde_json::from_str(&raw) { + Ok(v) => v, + Err(e) => { + items.push(DiagItem::error(cat, format!("invalid state JSON: {e}"))); + return; + } + }; + + // Daemon heartbeat freshness let updated_at = snapshot .get("updated_at") .and_then(serde_json::Value::as_str) @@ -33,28 +424,32 @@ pub fn run(config: &Config) -> Result<()> { .signed_duration_since(ts.with_timezone(&Utc)) .num_seconds(); if age <= DAEMON_STALE_SECONDS { - println!(" ✅ daemon heartbeat fresh ({age}s ago)"); + items.push(DiagItem::ok(cat, format!("heartbeat fresh ({age}s ago)"))); } else { - println!(" ❌ daemon heartbeat stale ({age}s ago)"); + items.push(DiagItem::error( + cat, + format!("heartbeat stale ({age}s ago)"), + )); } } else { - println!(" ❌ invalid daemon timestamp: {updated_at}"); + items.push(DiagItem::error( + cat, + format!("invalid daemon timestamp: {updated_at}"), + )); } - let mut channel_count = 0_u32; - let mut stale_channels = 0_u32; - + // Components if let Some(components) = snapshot .get("components") .and_then(serde_json::Value::as_object) { + // Scheduler if let Some(scheduler) = components.get("scheduler") { let scheduler_ok = scheduler .get("status") .and_then(serde_json::Value::as_str) .is_some_and(|s| s == "ok"); - - let scheduler_last_ok = scheduler + let scheduler_age = scheduler .get("last_ok") .and_then(serde_json::Value::as_str) .and_then(parse_rfc3339) @@ -62,22 +457,28 @@ pub fn run(config: &Config) -> Result<()> { Utc::now().signed_duration_since(dt).num_seconds() }); - if scheduler_ok && scheduler_last_ok <= SCHEDULER_STALE_SECONDS { - println!(" ✅ scheduler healthy (last ok {scheduler_last_ok}s ago)"); + if scheduler_ok && scheduler_age <= SCHEDULER_STALE_SECONDS { + items.push(DiagItem::ok( + cat, + format!("scheduler healthy (last ok {scheduler_age}s ago)"), + )); } else { - println!( - " ❌ scheduler unhealthy/stale (status_ok={scheduler_ok}, age={scheduler_last_ok}s)" - ); + items.push(DiagItem::error( + cat, + format!("scheduler unhealthy (ok={scheduler_ok}, age={scheduler_age}s)"), + )); } } else { - println!(" ❌ scheduler component missing"); + items.push(DiagItem::warn(cat, "scheduler component not tracked yet")); } + // Channels + let mut channel_count = 0u32; + let mut stale = 0u32; for (name, component) in components { if !name.starts_with("channel:") { continue; } - channel_count += 1; let status_ok = component .get("status") @@ -92,25 +493,273 @@ pub fn run(config: &Config) -> Result<()> { }); if status_ok && age <= CHANNEL_STALE_SECONDS { - println!(" ✅ {name} fresh (last ok {age}s ago)"); + items.push(DiagItem::ok(cat, format!("{name} fresh ({age}s ago)"))); } else { - stale_channels += 1; - println!(" ❌ {name} stale/unhealthy (status_ok={status_ok}, age={age}s)"); + stale += 1; + items.push(DiagItem::error( + cat, + format!("{name} stale (ok={status_ok}, age={age}s)"), + )); } } - } - if channel_count == 0 { - println!(" ℹ️ no channel components tracked in state yet"); - } else { - println!(" Channel summary: {channel_count} total, {stale_channels} stale"); + if channel_count == 0 { + items.push(DiagItem::warn(cat, "no channel components tracked yet")); + } else if stale > 0 { + items.push(DiagItem::warn( + cat, + format!("{channel_count} channels, {stale} stale"), + )); + } } - - Ok(()) } +// ── Environment checks ─────────────────────────────────────────── + +fn check_environment(items: &mut Vec) { + let cat = "environment"; + + // git + check_command_available("git", &["--version"], cat, items); + + // Shell + let shell = std::env::var("SHELL").unwrap_or_default(); + if shell.is_empty() { + items.push(DiagItem::warn(cat, "$SHELL not set")); + } else { + items.push(DiagItem::ok(cat, format!("shell: {shell}"))); + } + + // HOME + if std::env::var("HOME").is_ok() || std::env::var("USERPROFILE").is_ok() { + items.push(DiagItem::ok(cat, "home directory env set")); + } else { + items.push(DiagItem::error( + cat, + "neither $HOME nor $USERPROFILE is set", + )); + } + + // Optional tools + check_command_available("curl", &["--version"], cat, items); +} + +fn check_command_available(cmd: &str, args: &[&str], cat: &'static str, items: &mut Vec) { + match std::process::Command::new(cmd) + .args(args) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .output() + { + Ok(output) if output.status.success() => { + let ver = String::from_utf8_lossy(&output.stdout); + let first_line = ver.lines().next().unwrap_or("").trim(); + let display = truncate_for_display(first_line, COMMAND_VERSION_PREVIEW_CHARS); + items.push(DiagItem::ok(cat, format!("{cmd}: {display}"))); + } + Ok(_) => { + items.push(DiagItem::warn( + cat, + format!("{cmd} found but returned non-zero"), + )); + } + Err(_) => { + items.push(DiagItem::warn(cat, format!("{cmd} not found in PATH"))); + } + } +} + +fn truncate_for_display(input: &str, max_chars: usize) -> String { + let mut chars = input.chars(); + let preview: String = chars.by_ref().take(max_chars).collect(); + if chars.next().is_some() { + format!("{preview}…") + } else { + preview + } +} + +// ── Helpers ────────────────────────────────────────────────────── + fn parse_rfc3339(raw: &str) -> Option> { DateTime::parse_from_rfc3339(raw) .ok() .map(|dt| dt.with_timezone(&Utc)) } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn provider_validation_checks_custom_url_shape() { + assert!(provider_validation_error("openrouter").is_none()); + assert!(provider_validation_error("custom:https://example.com").is_none()); + assert!(provider_validation_error("anthropic-custom:https://example.com").is_none()); + + let invalid_custom = provider_validation_error("custom:").unwrap_or_default(); + assert!(invalid_custom.contains("requires a URL")); + + let invalid_unknown = provider_validation_error("totally-fake").unwrap_or_default(); + assert!(invalid_unknown.contains("Unknown provider")); + } + + #[test] + fn diag_item_icons() { + assert_eq!(DiagItem::ok("t", "m").icon(), "✅"); + assert_eq!(DiagItem::warn("t", "m").icon(), "⚠️ "); + assert_eq!(DiagItem::error("t", "m").icon(), "❌"); + } + + #[test] + fn config_validation_catches_bad_temperature() { + let mut config = Config::default(); + config.default_temperature = 5.0; + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let temp_item = items.iter().find(|i| i.message.contains("temperature")); + assert!(temp_item.is_some()); + assert_eq!(temp_item.unwrap().severity, Severity::Error); + } + + #[test] + fn config_validation_accepts_valid_temperature() { + let mut config = Config::default(); + config.default_temperature = 0.7; + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let temp_item = items.iter().find(|i| i.message.contains("temperature")); + assert!(temp_item.is_some()); + assert_eq!(temp_item.unwrap().severity, Severity::Ok); + } + + #[test] + fn config_validation_warns_no_channels() { + let config = Config::default(); + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let ch_item = items.iter().find(|i| i.message.contains("channel")); + assert!(ch_item.is_some()); + assert_eq!(ch_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn config_validation_catches_unknown_provider() { + let mut config = Config::default(); + config.default_provider = Some("totally-fake".into()); + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let prov_item = items + .iter() + .find(|i| i.message.contains("default provider")); + assert!(prov_item.is_some()); + assert_eq!(prov_item.unwrap().severity, Severity::Error); + } + + #[test] + fn config_validation_catches_malformed_custom_provider() { + let mut config = Config::default(); + config.default_provider = Some("custom:".into()); + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + + let prov_item = items.iter().find(|item| { + item.message + .contains("default provider \"custom:\" is invalid") + }); + assert!(prov_item.is_some()); + assert_eq!(prov_item.unwrap().severity, Severity::Error); + } + + #[test] + fn config_validation_accepts_custom_provider() { + let mut config = Config::default(); + config.default_provider = Some("custom:https://my-api.com".into()); + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let prov_item = items.iter().find(|i| i.message.contains("is valid")); + assert!(prov_item.is_some()); + assert_eq!(prov_item.unwrap().severity, Severity::Ok); + } + + #[test] + fn config_validation_warns_bad_fallback() { + let mut config = Config::default(); + config.reliability.fallback_providers = vec!["fake-provider".into()]; + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let fb_item = items + .iter() + .find(|i| i.message.contains("fallback provider")); + assert!(fb_item.is_some()); + assert_eq!(fb_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn config_validation_warns_bad_custom_fallback() { + let mut config = Config::default(); + config.reliability.fallback_providers = vec!["custom:".into()]; + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + + let fb_item = items.iter().find(|item| { + item.message + .contains("fallback provider \"custom:\" is invalid") + }); + assert!(fb_item.is_some()); + assert_eq!(fb_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn config_validation_warns_empty_model_route() { + let mut config = Config::default(); + config.model_routes = vec![crate::config::ModelRouteConfig { + hint: "fast".into(), + provider: "groq".into(), + model: String::new(), + api_key: None, + }]; + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let route_item = items.iter().find(|i| i.message.contains("empty model")); + assert!(route_item.is_some()); + assert_eq!(route_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn environment_check_finds_git() { + let mut items = Vec::new(); + check_environment(&mut items); + let git_item = items.iter().find(|i| i.message.starts_with("git:")); + // git should be available in any CI/dev environment + assert!(git_item.is_some()); + assert_eq!(git_item.unwrap().severity, Severity::Ok); + } + + #[test] + fn parse_df_available_mb_uses_last_data_line() { + let stdout = + "Filesystem 1M-blocks Used Available Use% Mounted on\n/dev/sda1 1000 500 500 50% /\n"; + assert_eq!(parse_df_available_mb(stdout), Some(500)); + } + + #[test] + fn truncate_for_display_preserves_utf8_boundaries() { + let preview = truncate_for_display("版本号-alpha-build", 3); + assert_eq!(preview, "版本号…"); + } + + #[test] + fn workspace_probe_path_is_hidden_and_unique() { + let tmp = TempDir::new().unwrap(); + let first = workspace_probe_path(tmp.path()); + let second = workspace_probe_path(tmp.path()); + + assert_ne!(first, second); + assert!(first + .file_name() + .and_then(|name| name.to_str()) + .is_some_and(|name| name.starts_with(".zeroclaw_doctor_probe_"))); + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 4290451..988b780 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -7,11 +7,15 @@ //! - Request timeouts (30s) to prevent slow-loris attacks //! - Header sanitization (handled by axum/hyper) -use crate::channels::{Channel, WhatsAppChannel}; +use crate::channels::{Channel, SendMessage, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::providers::{self, Provider}; +use crate::runtime; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; +use crate::security::SecurityPolicy; +use crate::tools; +use crate::util::truncate_with_ellipsis; use anyhow::Result; use axum::{ body::Bytes, @@ -21,16 +25,153 @@ use axum::{ routing::{get, post}, Router, }; +use parking_lot::Mutex; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tower_http::limit::RequestBodyLimitLayer; use tower_http::timeout::TimeoutLayer; +use uuid::Uuid; /// Maximum request body size (64KB) — prevents memory exhaustion pub const MAX_BODY_SIZE: usize = 65_536; /// Request timeout (30s) — prevents slow-loris attacks pub const REQUEST_TIMEOUT_SECS: u64 = 30; +/// Sliding window used by gateway rate limiting. +pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; + +fn webhook_memory_key() -> String { + format!("webhook_msg_{}", Uuid::new_v4()) +} + +fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String { + format!("whatsapp_{}_{}", msg.sender, msg.id) +} + +fn hash_webhook_secret(value: &str) -> String { + use sha2::{Digest, Sha256}; + + let digest = Sha256::digest(value.as_bytes()); + hex::encode(digest) +} + +/// How often the rate limiter sweeps stale IP entries from its map. +const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes + +#[derive(Debug)] +struct SlidingWindowRateLimiter { + limit_per_window: u32, + window: Duration, + requests: Mutex<(HashMap>, Instant)>, +} + +impl SlidingWindowRateLimiter { + fn new(limit_per_window: u32, window: Duration) -> Self { + Self { + limit_per_window, + window, + requests: Mutex::new((HashMap::new(), Instant::now())), + } + } + + fn allow(&self, key: &str) -> bool { + if self.limit_per_window == 0 { + return true; + } + + let now = Instant::now(); + let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now); + + let mut guard = self.requests.lock(); + let (requests, last_sweep) = &mut *guard; + + // Periodic sweep: remove IPs with no recent requests + if last_sweep.elapsed() >= Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS) { + requests.retain(|_, timestamps| { + timestamps.retain(|t| *t > cutoff); + !timestamps.is_empty() + }); + *last_sweep = now; + } + + let entry = requests.entry(key.to_owned()).or_default(); + entry.retain(|instant| *instant > cutoff); + + if entry.len() >= self.limit_per_window as usize { + return false; + } + + entry.push(now); + true + } +} + +#[derive(Debug)] +pub struct GatewayRateLimiter { + pair: SlidingWindowRateLimiter, + webhook: SlidingWindowRateLimiter, +} + +impl GatewayRateLimiter { + fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self { + let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS); + Self { + pair: SlidingWindowRateLimiter::new(pair_per_minute, window), + webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window), + } + } + + fn allow_pair(&self, key: &str) -> bool { + self.pair.allow(key) + } + + fn allow_webhook(&self, key: &str) -> bool { + self.webhook.allow(key) + } +} + +#[derive(Debug)] +pub struct IdempotencyStore { + ttl: Duration, + keys: Mutex>, +} + +impl IdempotencyStore { + fn new(ttl: Duration) -> Self { + Self { + ttl, + keys: Mutex::new(HashMap::new()), + } + } + + /// Returns true if this key is new and is now recorded. + fn record_if_new(&self, key: &str) -> bool { + let now = Instant::now(); + let mut keys = self.keys.lock(); + + keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl); + + if keys.contains_key(key) { + return false; + } + + keys.insert(key.to_owned(), now); + true + } +} + +fn client_key_from_headers(headers: &HeaderMap) -> String { + for header_name in ["X-Forwarded-For", "X-Real-IP"] { + if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) { + let first = value.split(',').next().unwrap_or("").trim(); + if !first.is_empty() { + return first.to_owned(); + } + } + } + "unknown".into() +} /// Shared state for all axum handlers #[derive(Clone)] @@ -40,9 +181,14 @@ pub struct AppState { pub temperature: f64, pub mem: Arc, pub auto_save: bool, - pub webhook_secret: Option>, + /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext. + pub webhook_secret_hash: Option>, pub pairing: Arc, + pub rate_limiter: Arc, + pub idempotency_store: Arc, pub whatsapp: Option>, + /// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`) + pub whatsapp_app_secret: Option>, } /// Run the HTTP gateway using axum with proper HTTP/1.1 compliance. @@ -66,26 +212,58 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); let model = config .default_model .clone() - .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); + .unwrap_or_else(|| "anthropic/claude-sonnet-4".into()); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory( &config.memory, &config.workspace_dir, config.api_key.as_deref(), )?); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) + } else { + (None, None) + }; + + let _tools_registry = Arc::new(tools::all_tools_with_runtime( + Arc::new(config.clone()), + &security, + runtime, + Arc::clone(&mem), + composio_key, + composio_entity_id, + &config.browser, + &config.http_request, + &config.workspace_dir, + &config.agents, + config.api_key.as_deref(), + &config, + )); // Extract webhook secret for authentication - let webhook_secret: Option> = config - .channels_config - .webhook - .as_ref() - .and_then(|w| w.secret.as_deref()) - .map(Arc::from); + let webhook_secret_hash: Option> = + config.channels_config.webhook.as_ref().and_then(|webhook| { + webhook.secret.as_ref().and_then(|raw_secret| { + let trimmed_secret = raw_secret.trim(); + (!trimmed_secret.is_empty()) + .then(|| Arc::::from(hash_webhook_secret(trimmed_secret))) + }) + }); // WhatsApp channel (if configured) let whatsapp_channel: Option> = @@ -98,11 +276,37 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { )) }); + // WhatsApp app secret for webhook signature verification + // Priority: environment variable > config file + let whatsapp_app_secret: Option> = std::env::var("ZEROCLAW_WHATSAPP_APP_SECRET") + .ok() + .and_then(|secret| { + let secret = secret.trim(); + (!secret.is_empty()).then(|| secret.to_owned()) + }) + .or_else(|| { + config.channels_config.whatsapp.as_ref().and_then(|wa| { + wa.app_secret + .as_deref() + .map(str::trim) + .filter(|secret| !secret.is_empty()) + .map(ToOwned::to_owned) + }) + }) + .map(Arc::from); + // ── Pairing guard ────────────────────────────────────── let pairing = Arc::new(PairingGuard::new( config.gateway.require_pairing, &config.gateway.paired_tokens, )); + let rate_limiter = Arc::new(GatewayRateLimiter::new( + config.gateway.pair_rate_limit_per_minute, + config.gateway.webhook_rate_limit_per_minute, + )); + let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs( + config.gateway.idempotency_ttl_secs.max(1), + ))); // ── Tunnel ──────────────────────────────────────────────── let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?; @@ -145,9 +349,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { } else { println!(" ⚠️ Pairing: DISABLED (all requests accepted)"); } - if webhook_secret.is_some() { - println!(" 🔒 Webhook secret: ENABLED"); - } println!(" Press Ctrl+C to stop.\n"); crate::health::mark_component_ok("gateway"); @@ -159,9 +360,12 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { temperature, mem, auto_save: config.memory.auto_save, - webhook_secret, + webhook_secret_hash, pairing, + rate_limiter, + idempotency_store, whatsapp: whatsapp_channel, + whatsapp_app_secret, }; // Build router with middleware @@ -200,6 +404,16 @@ async fn handle_health(State(state): State) -> impl IntoResponse { /// POST /pair — exchange one-time code for bearer token async fn handle_pair(State(state): State, headers: HeaderMap) -> impl IntoResponse { + let client_key = client_key_from_headers(&headers); + if !state.rate_limiter.allow_pair(&client_key) { + tracing::warn!("/pair rate limit exceeded for key: {client_key}"); + let err = serde_json::json!({ + "error": "Too many pairing requests. Please retry later.", + "retry_after": RATE_LIMIT_WINDOW_SECS, + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)); + } + let code = headers .get("X-Pairing-Code") .and_then(|v| v.to_str().ok()) @@ -245,6 +459,16 @@ async fn handle_webhook( headers: HeaderMap, body: Result, axum::extract::rejection::JsonRejection>, ) -> impl IntoResponse { + let client_key = client_key_from_headers(&headers); + if !state.rate_limiter.allow_webhook(&client_key) { + tracing::warn!("/webhook rate limit exceeded for key: {client_key}"); + let err = serde_json::json!({ + "error": "Too many webhook requests. Please retry later.", + "retry_after": RATE_LIMIT_WINDOW_SECS, + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)); + } + // ── Bearer token auth (pairing) ── if state.pairing.require_pairing() { let auth = headers @@ -262,12 +486,15 @@ async fn handle_webhook( } // ── Webhook secret auth (optional, additional layer) ── - if let Some(ref secret) = state.webhook_secret { - let header_val = headers + if let Some(ref secret_hash) = state.webhook_secret_hash { + let header_hash = headers .get("X-Webhook-Secret") - .and_then(|v| v.to_str().ok()); - match header_val { - Some(val) if constant_time_eq(val, secret.as_ref()) => {} + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(hash_webhook_secret); + match header_hash { + Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {} _ => { tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret"); let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); @@ -280,25 +507,45 @@ async fn handle_webhook( let Json(webhook_body) = match body { Ok(b) => b, Err(e) => { + tracing::warn!("Webhook JSON parse error: {e}"); let err = serde_json::json!({ - "error": format!("Invalid JSON: {e}. Expected: {{\"message\": \"...\"}}") + "error": "Invalid JSON body. Expected: {\"message\": \"...\"}" }); return (StatusCode::BAD_REQUEST, Json(err)); } }; + // ── Idempotency (optional) ── + if let Some(idempotency_key) = headers + .get("X-Idempotency-Key") + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + { + if !state.idempotency_store.record_if_new(idempotency_key) { + tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})"); + let body = serde_json::json!({ + "status": "duplicate", + "idempotent": true, + "message": "Request already processed for this idempotency key" + }); + return (StatusCode::OK, Json(body)); + } + } + let message = &webhook_body.message; if state.auto_save { + let key = webhook_memory_key(); let _ = state .mem - .store("webhook_msg", message, MemoryCategory::Conversation) + .store(&key, message, MemoryCategory::Conversation, None) .await; } match state .provider - .chat(message, &state.model, state.temperature) + .simple_chat(message, &state.model, state.temperature) .await { Ok(response) => { @@ -306,8 +553,11 @@ async fn handle_webhook( (StatusCode::OK, Json(body)) } Err(e) => { - tracing::error!("LLM error: {e:#}"); - let err = serde_json::json!({"error": "Internal error processing your request"}); + tracing::error!( + "Webhook provider error: {}", + providers::sanitize_api_error(&e.to_string()) + ); + let err = serde_json::json!({"error": "LLM request failed"}); (StatusCode::INTERNAL_SERVER_ERROR, Json(err)) } } @@ -333,10 +583,12 @@ async fn handle_whatsapp_verify( return (StatusCode::NOT_FOUND, "WhatsApp not configured".to_string()); }; - // Verify the token matches - if params.mode.as_deref() == Some("subscribe") - && params.verify_token.as_deref() == Some(wa.verify_token()) - { + // Verify the token matches (constant-time comparison to prevent timing attacks) + let token_matches = params + .verify_token + .as_deref() + .is_some_and(|t| constant_time_eq(t, wa.verify_token())); + if params.mode.as_deref() == Some("subscribe") && token_matches { if let Some(ch) = params.challenge { tracing::info!("WhatsApp webhook verified successfully"); return (StatusCode::OK, ch); @@ -348,8 +600,39 @@ async fn handle_whatsapp_verify( (StatusCode::FORBIDDEN, "Forbidden".to_string()) } +/// Verify `WhatsApp` webhook signature (`X-Hub-Signature-256`). +/// Returns true if the signature is valid, false otherwise. +/// See: +pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header: &str) -> bool { + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + // Signature format: "sha256=" + let Some(hex_sig) = signature_header.strip_prefix("sha256=") else { + return false; + }; + + // Decode hex signature + let Ok(expected) = hex::decode(hex_sig) else { + return false; + }; + + // Compute HMAC-SHA256 + let Ok(mut mac) = Hmac::::new_from_slice(app_secret.as_bytes()) else { + return false; + }; + mac.update(body); + + // Constant-time comparison + mac.verify_slice(&expected).is_ok() +} + /// POST /whatsapp — incoming message webhook -async fn handle_whatsapp_message(State(state): State, body: Bytes) -> impl IntoResponse { +async fn handle_whatsapp_message( + State(state): State, + headers: HeaderMap, + body: Bytes, +) -> impl IntoResponse { let Some(ref wa) = state.whatsapp else { return ( StatusCode::NOT_FOUND, @@ -357,6 +640,29 @@ async fn handle_whatsapp_message(State(state): State, body: Bytes) -> ); }; + // ── Security: Verify X-Hub-Signature-256 if app_secret is configured ── + if let Some(ref app_secret) = state.whatsapp_app_secret { + let signature = headers + .get("X-Hub-Signature-256") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + if !verify_whatsapp_signature(app_secret, &body, signature) { + tracing::warn!( + "WhatsApp webhook signature verification failed (signature: {})", + if signature.is_empty() { + "missing" + } else { + "invalid" + } + ); + return ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"error": "Invalid signature"})), + ); + } + } + // Parse JSON body let Ok(payload) = serde_json::from_slice::(&body) else { return ( @@ -378,41 +684,40 @@ async fn handle_whatsapp_message(State(state): State, body: Bytes) -> tracing::info!( "WhatsApp message from {}: {}", msg.sender, - if msg.content.len() > 50 { - format!("{}...", &msg.content[..50]) - } else { - msg.content.clone() - } + truncate_with_ellipsis(&msg.content, 50) ); // Auto-save to memory if state.auto_save { + let key = whatsapp_memory_key(msg); let _ = state .mem - .store( - &format!("whatsapp_{}", msg.sender), - &msg.content, - MemoryCategory::Conversation, - ) + .store(&key, &msg.content, MemoryCategory::Conversation, None) .await; } // Call the LLM match state .provider - .chat(&msg.content, &state.model, state.temperature) + .simple_chat(&msg.content, &state.model, state.temperature) .await { Ok(response) => { // Send reply via WhatsApp - if let Err(e) = wa.send(&response, &msg.sender).await { + if let Err(e) = wa + .send(&SendMessage::new(response, &msg.reply_target)) + .await + { tracing::error!("Failed to send WhatsApp reply: {e}"); } } Err(e) => { tracing::error!("LLM error for WhatsApp message: {e:#}"); let _ = wa - .send("Sorry, I couldn't process your message right now.", &msg.sender) + .send(&SendMessage::new( + "Sorry, I couldn't process your message right now.", + &msg.reply_target, + )) .await; } } @@ -425,6 +730,15 @@ async fn handle_whatsapp_message(State(state): State, body: Bytes) -> #[cfg(test)] mod tests { use super::*; + use crate::channels::traits::ChannelMessage; + use crate::memory::{Memory, MemoryCategory, MemoryEntry}; + use crate::providers::Provider; + use async_trait::async_trait; + use axum::http::HeaderValue; + use axum::response::IntoResponse; + use http_body_util::BodyExt; + use parking_lot::Mutex; + use std::sync::atomic::{AtomicUsize, Ordering}; #[test] fn security_body_limit_is_64kb() { @@ -463,4 +777,620 @@ mod tests { fn assert_clone() {} assert_clone::(); } + + #[test] + fn gateway_rate_limiter_blocks_after_limit() { + let limiter = GatewayRateLimiter::new(2, 2); + assert!(limiter.allow_pair("127.0.0.1")); + assert!(limiter.allow_pair("127.0.0.1")); + assert!(!limiter.allow_pair("127.0.0.1")); + } + + #[test] + fn rate_limiter_sweep_removes_stale_entries() { + let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60)); + // Add entries for multiple IPs + assert!(limiter.allow("ip-1")); + assert!(limiter.allow("ip-2")); + assert!(limiter.allow("ip-3")); + + { + let guard = limiter.requests.lock(); + assert_eq!(guard.0.len(), 3); + } + + // Force a sweep by backdating last_sweep + { + let mut guard = limiter.requests.lock(); + guard.1 = Instant::now() + .checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1)) + .unwrap(); + // Clear timestamps for ip-2 and ip-3 to simulate stale entries + guard.0.get_mut("ip-2").unwrap().clear(); + guard.0.get_mut("ip-3").unwrap().clear(); + } + + // Next allow() call should trigger sweep and remove stale entries + assert!(limiter.allow("ip-1")); + + { + let guard = limiter.requests.lock(); + assert_eq!(guard.0.len(), 1, "Stale entries should have been swept"); + assert!(guard.0.contains_key("ip-1")); + } + } + + #[test] + fn rate_limiter_zero_limit_always_allows() { + let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60)); + for _ in 0..100 { + assert!(limiter.allow("any-key")); + } + } + + #[test] + fn idempotency_store_rejects_duplicate_key() { + let store = IdempotencyStore::new(Duration::from_secs(30)); + assert!(store.record_if_new("req-1")); + assert!(!store.record_if_new("req-1")); + assert!(store.record_if_new("req-2")); + } + + #[test] + fn webhook_memory_key_is_unique() { + let key1 = webhook_memory_key(); + let key2 = webhook_memory_key(); + + assert!(key1.starts_with("webhook_msg_")); + assert!(key2.starts_with("webhook_msg_")); + assert_ne!(key1, key2); + } + + #[test] + fn whatsapp_memory_key_includes_sender_and_message_id() { + let msg = ChannelMessage { + id: "wamid-123".into(), + sender: "+1234567890".into(), + reply_target: "+1234567890".into(), + content: "hello".into(), + channel: "whatsapp".into(), + timestamp: 1, + }; + + let key = whatsapp_memory_key(&msg); + assert_eq!(key, "whatsapp_+1234567890_wamid-123"); + } + + #[derive(Default)] + struct MockMemory; + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + #[derive(Default)] + struct MockProvider { + calls: AtomicUsize, + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".into()) + } + } + + #[derive(Default)] + struct TrackingMemory { + keys: Mutex>, + } + + #[async_trait] + impl Memory for TrackingMemory { + fn name(&self) -> &str { + "tracking" + } + + async fn store( + &self, + key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + self.keys.lock().push(key.to_string()); + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + let size = self.keys.lock().len(); + Ok(size) + } + + async fn health_check(&self) -> bool { + true + } + } + + #[tokio::test] + async fn webhook_idempotency_skips_duplicate_provider_calls() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123")); + + let body = Ok(Json(WebhookBody { + message: "hello".into(), + })); + let first = handle_webhook(State(state.clone()), headers.clone(), body) + .await + .into_response(); + assert_eq!(first.status(), StatusCode::OK); + + let body = Ok(Json(WebhookBody { + message: "hello".into(), + })); + let second = handle_webhook(State(state), headers, body) + .await + .into_response(); + assert_eq!(second.status(), StatusCode::OK); + + let payload = second.into_body().collect().await.unwrap().to_bytes(); + let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + assert_eq!(parsed["status"], "duplicate"); + assert_eq!(parsed["idempotent"], true); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn webhook_autosave_stores_distinct_keys_per_request() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + + let tracking_impl = Arc::new(TrackingMemory::default()); + let memory: Arc = tracking_impl.clone(); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: true, + webhook_secret_hash: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let headers = HeaderMap::new(); + + let body1 = Ok(Json(WebhookBody { + message: "hello one".into(), + })); + let first = handle_webhook(State(state.clone()), headers.clone(), body1) + .await + .into_response(); + assert_eq!(first.status(), StatusCode::OK); + + let body2 = Ok(Json(WebhookBody { + message: "hello two".into(), + })); + let second = handle_webhook(State(state), headers, body2) + .await + .into_response(); + assert_eq!(second.status(), StatusCode::OK); + + let keys = tracking_impl.keys.lock().clone(); + assert_eq!(keys.len(), 2); + assert_ne!(keys[0], keys[1]); + assert!(keys[0].starts_with("webhook_msg_")); + assert!(keys[1].starts_with("webhook_msg_")); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); + } + + #[test] + fn webhook_secret_hash_is_deterministic_and_nonempty() { + let one = hash_webhook_secret("secret-value"); + let two = hash_webhook_secret("secret-value"); + let other = hash_webhook_secret("other-value"); + + assert_eq!(one, two); + assert_ne!(one, other); + assert_eq!(one.len(), 64); + } + + #[tokio::test] + async fn webhook_secret_hash_rejects_missing_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let response = handle_webhook( + State(state), + HeaderMap::new(), + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn webhook_secret_hash_rejects_invalid_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret")); + + let response = handle_webhook( + State(state), + headers, + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn webhook_secret_hash_accepts_valid_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret")); + + let response = handle_webhook( + State(state), + headers, + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); + } + + // ══════════════════════════════════════════════════════════ + // WhatsApp Signature Verification Tests (CWE-345 Prevention) + // ══════════════════════════════════════════════════════════ + + fn compute_whatsapp_signature_hex(secret: &str, body: &[u8]) -> String { + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); + mac.update(body); + hex::encode(mac.finalize().into_bytes()) + } + + fn compute_whatsapp_signature_header(secret: &str, body: &[u8]) -> String { + format!("sha256={}", compute_whatsapp_signature_hex(secret, body)) + } + + #[test] + fn whatsapp_signature_valid() { + // Test with known values + let app_secret = "test_secret_key_12345"; + let body = b"test body content"; + + let signature_header = compute_whatsapp_signature_header(app_secret, body); + + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_invalid_wrong_secret() { + let app_secret = "correct_secret_key_abc"; + let wrong_secret = "wrong_secret_key_xyz"; + let body = b"test body content"; + + let signature_header = compute_whatsapp_signature_header(wrong_secret, body); + + assert!(!verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_invalid_wrong_body() { + let app_secret = "test_secret_key_12345"; + let original_body = b"original body"; + let tampered_body = b"tampered body"; + + let signature_header = compute_whatsapp_signature_header(app_secret, original_body); + + // Verify with tampered body should fail + assert!(!verify_whatsapp_signature( + app_secret, + tampered_body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_missing_prefix() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + // Signature without "sha256=" prefix + let signature_header = "abc123def456"; + + assert!(!verify_whatsapp_signature( + app_secret, + body, + signature_header + )); + } + + #[test] + fn whatsapp_signature_empty_header() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + assert!(!verify_whatsapp_signature(app_secret, body, "")); + } + + #[test] + fn whatsapp_signature_invalid_hex() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + // Invalid hex characters + let signature_header = "sha256=not_valid_hex_zzz"; + + assert!(!verify_whatsapp_signature( + app_secret, + body, + signature_header + )); + } + + #[test] + fn whatsapp_signature_empty_body() { + let app_secret = "test_secret_key_12345"; + let body = b""; + + let signature_header = compute_whatsapp_signature_header(app_secret, body); + + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_unicode_body() { + let app_secret = "test_secret_key_12345"; + let body = "Hello 🦀 World".as_bytes(); + + let signature_header = compute_whatsapp_signature_header(app_secret, body); + + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_json_payload() { + let app_secret = "test_app_secret_key_xyz"; + let body = br#"{"entry":[{"changes":[{"value":{"messages":[{"from":"1234567890","text":{"body":"Hello"}}]}}]}]}"#; + + let signature_header = compute_whatsapp_signature_header(app_secret, body); + + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_case_sensitive_prefix() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + + // Wrong case prefix should fail + let wrong_prefix = format!("SHA256={hex_sig}"); + assert!(!verify_whatsapp_signature(app_secret, body, &wrong_prefix)); + + // Correct prefix should pass + let correct_prefix = format!("sha256={hex_sig}"); + assert!(verify_whatsapp_signature(app_secret, body, &correct_prefix)); + } + + #[test] + fn whatsapp_signature_truncated_hex() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + let truncated = &hex_sig[..32]; // Only half the signature + let signature_header = format!("sha256={truncated}"); + + assert!(!verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } + + #[test] + fn whatsapp_signature_extra_bytes() { + let app_secret = "test_secret_key_12345"; + let body = b"test body"; + + let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + let extended = format!("{hex_sig}deadbeef"); + let signature_header = format!("sha256={extended}"); + + assert!(!verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); + } } diff --git a/src/hardware/discover.rs b/src/hardware/discover.rs new file mode 100644 index 0000000..4bbf31f --- /dev/null +++ b/src/hardware/discover.rs @@ -0,0 +1,45 @@ +//! USB device discovery — enumerate devices and enrich with board registry. + +use super::registry; +use anyhow::Result; +use nusb::MaybeFuture; + +/// Information about a discovered USB device. +#[derive(Debug, Clone)] +pub struct UsbDeviceInfo { + pub bus_id: String, + pub device_address: u8, + pub vid: u16, + pub pid: u16, + pub product_string: Option, + pub board_name: Option, + pub architecture: Option, +} + +/// Enumerate all connected USB devices and enrich with board registry lookup. +#[cfg(feature = "hardware")] +pub fn list_usb_devices() -> Result> { + let mut devices = Vec::new(); + + let iter = nusb::list_devices() + .wait() + .map_err(|e| anyhow::anyhow!("USB enumeration failed: {e}"))?; + + for dev in iter { + let vid = dev.vendor_id(); + let pid = dev.product_id(); + let board = registry::lookup_board(vid, pid); + + devices.push(UsbDeviceInfo { + bus_id: dev.bus_id().to_string(), + device_address: dev.device_address(), + vid, + pid, + product_string: dev.product_string().map(String::from), + board_name: board.map(|b| b.name.to_string()), + architecture: board.and_then(|b| b.architecture.map(String::from)), + }); + } + + Ok(devices) +} diff --git a/src/hardware/introspect.rs b/src/hardware/introspect.rs new file mode 100644 index 0000000..21b5744 --- /dev/null +++ b/src/hardware/introspect.rs @@ -0,0 +1,121 @@ +//! Device introspection — correlate serial path with USB device info. + +use super::discover; +use super::registry; +use anyhow::Result; + +/// Result of introspecting a device by path. +#[derive(Debug, Clone)] +pub struct IntrospectResult { + pub path: String, + pub vid: Option, + pub pid: Option, + pub board_name: Option, + pub architecture: Option, + pub memory_map_note: String, +} + +/// Introspect a device by its serial path (e.g. /dev/ttyACM0, /dev/tty.usbmodem*). +/// Attempts to correlate with USB devices from discovery. +#[cfg(feature = "hardware")] +pub fn introspect_device(path: &str) -> Result { + let devices = discover::list_usb_devices()?; + + // Try to correlate path with a discovered device. + // On Linux, /dev/ttyACM0 corresponds to a CDC-ACM device; we may have multiple. + // Best-effort: if we have exactly one CDC-like device, use it. Otherwise unknown. + let matched = if devices.len() == 1 { + devices.first().cloned() + } else if devices.is_empty() { + None + } else { + // Multiple devices: try to match by path. On Linux we could use sysfs; + // for stub, pick first known board or first device. + devices + .iter() + .find(|d| d.board_name.is_some()) + .cloned() + .or_else(|| devices.first().cloned()) + }; + + let (vid, pid, board_name, architecture) = match matched { + Some(d) => (Some(d.vid), Some(d.pid), d.board_name, d.architecture), + None => (None, None, None, None), + }; + + let board_info = vid.and_then(|v| pid.and_then(|p| registry::lookup_board(v, p))); + let architecture = + architecture.or_else(|| board_info.and_then(|b| b.architecture.map(String::from))); + let board_name = board_name.or_else(|| board_info.map(|b| b.name.to_string())); + + let memory_map_note = memory_map_for_board(board_name.as_deref()); + + Ok(IntrospectResult { + path: path.to_string(), + vid, + pid, + board_name, + architecture, + memory_map_note, + }) +} + +/// Get memory map: via probe-rs when probe feature on and Nucleo, else static or stub. +#[cfg(feature = "hardware")] +fn memory_map_for_board(board_name: Option<&str>) -> String { + #[cfg(feature = "probe")] + if let Some(board) = board_name { + let chip = match board { + "nucleo-f401re" => "STM32F401RETx", + "nucleo-f411re" => "STM32F411RETx", + _ => return "Build with --features probe for live memory map (Nucleo)".to_string(), + }; + match probe_memory_map(chip) { + Ok(s) => return s, + Err(_) => return format!("probe-rs attach failed (chip {}). Connect via USB.", chip), + } + } + + #[cfg(not(feature = "probe"))] + let _ = board_name; + + "Build with --features probe for live memory map via USB".to_string() +} + +#[cfg(all(feature = "hardware", feature = "probe"))] +fn probe_memory_map(chip: &str) -> anyhow::Result { + use probe_rs::config::MemoryRegion; + use probe_rs::{Session, SessionConfig}; + + let session = Session::auto_attach(chip, SessionConfig::default()) + .map_err(|e| anyhow::anyhow!("{}", e))?; + let target = session.target(); + let mut out = String::new(); + for region in target.memory_map.iter() { + match region { + MemoryRegion::Ram(ram) => { + let (start, end) = (ram.range.start, ram.range.end); + out.push_str(&format!( + "RAM: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, + end, + (end - start) / 1024 + )); + } + MemoryRegion::Nvm(flash) => { + let (start, end) = (flash.range.start, flash.range.end); + out.push_str(&format!( + "Flash: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, + end, + (end - start) / 1024 + )); + } + _ => {} + } + } + if out.is_empty() { + out = "Could not read memory regions".to_string(); + } + Ok(out) +} diff --git a/src/hardware/mod.rs b/src/hardware/mod.rs new file mode 100644 index 0000000..18f6dcc --- /dev/null +++ b/src/hardware/mod.rs @@ -0,0 +1,230 @@ +//! Hardware discovery — USB device enumeration and introspection. +//! +//! See `docs/hardware-peripherals-design.md` for the full design. + +pub mod registry; + +#[cfg(feature = "hardware")] +pub mod discover; + +#[cfg(feature = "hardware")] +pub mod introspect; + +use crate::config::Config; +use anyhow::Result; + +// Re-export config types so wizard can use `hardware::HardwareConfig` etc. +pub use crate::config::{HardwareConfig, HardwareTransport}; + +/// A hardware device discovered during auto-scan. +#[derive(Debug, Clone)] +pub struct DiscoveredDevice { + pub name: String, + pub detail: Option, + pub device_path: Option, + pub transport: HardwareTransport, +} + +/// Auto-discover connected hardware devices. +/// Returns an empty vec on platforms without hardware support. +pub fn discover_hardware() -> Vec { + // USB/serial discovery is behind the "hardware" feature gate. + #[cfg(feature = "hardware")] + { + if let Ok(devices) = discover::list_usb_devices() { + return devices + .into_iter() + .map(|d| DiscoveredDevice { + name: d + .board_name + .unwrap_or_else(|| format!("{:04x}:{:04x}", d.vid, d.pid)), + detail: d.product_string, + device_path: None, + transport: if d.architecture.as_deref() == Some("native") { + HardwareTransport::Native + } else { + HardwareTransport::Serial + }, + }) + .collect(); + } + } + Vec::new() +} + +/// Return the recommended default wizard choice index based on discovered devices. +/// 0 = Native, 1 = Tethered/Serial, 2 = Debug Probe, 3 = Software Only +pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize { + if devices.is_empty() { + 3 // software only + } else { + 1 // tethered (most common for detected USB devices) + } +} + +/// Build a `HardwareConfig` from the wizard menu choice (0–3) and discovered devices. +pub fn config_from_wizard_choice(choice: usize, devices: &[DiscoveredDevice]) -> HardwareConfig { + match choice { + 0 => HardwareConfig { + enabled: true, + transport: HardwareTransport::Native, + ..HardwareConfig::default() + }, + 1 => { + let serial_port = devices + .iter() + .find(|d| d.transport == HardwareTransport::Serial) + .and_then(|d| d.device_path.clone()); + HardwareConfig { + enabled: true, + transport: HardwareTransport::Serial, + serial_port, + ..HardwareConfig::default() + } + } + 2 => HardwareConfig { + enabled: true, + transport: HardwareTransport::Probe, + ..HardwareConfig::default() + }, + _ => HardwareConfig::default(), // software only + } +} + +/// Handle `zeroclaw hardware` subcommands. +#[allow(clippy::module_name_repetitions)] +pub fn handle_command(cmd: crate::HardwareCommands, _config: &Config) -> Result<()> { + #[cfg(not(feature = "hardware"))] + { + let _ = &cmd; + println!("Hardware discovery requires the 'hardware' feature."); + println!("Build with: cargo build --features hardware"); + return Ok(()); + } + + #[cfg(feature = "hardware")] + match cmd { + crate::HardwareCommands::Discover => run_discover(), + crate::HardwareCommands::Introspect { path } => run_introspect(&path), + crate::HardwareCommands::Info { chip } => run_info(&chip), + } +} + +#[cfg(feature = "hardware")] +fn run_discover() -> Result<()> { + let devices = discover::list_usb_devices()?; + + if devices.is_empty() { + println!("No USB devices found."); + println!(); + println!("Connect a board (e.g. Nucleo-F401RE) via USB and try again."); + return Ok(()); + } + + println!("USB devices:"); + println!(); + for d in &devices { + let board = d.board_name.as_deref().unwrap_or("(unknown)"); + let arch = d.architecture.as_deref().unwrap_or("—"); + let product = d.product_string.as_deref().unwrap_or("—"); + println!( + " {:04x}:{:04x} {} {} {}", + d.vid, d.pid, board, arch, product + ); + } + println!(); + println!("Known boards: nucleo-f401re, nucleo-f411re, arduino-uno, arduino-mega, cp2102"); + + Ok(()) +} + +#[cfg(feature = "hardware")] +fn run_introspect(path: &str) -> Result<()> { + let result = introspect::introspect_device(path)?; + + println!("Device at {}:", result.path); + println!(); + if let (Some(vid), Some(pid)) = (result.vid, result.pid) { + println!(" VID:PID {:04x}:{:04x}", vid, pid); + } else { + println!(" VID:PID (could not correlate with USB device)"); + } + if let Some(name) = &result.board_name { + println!(" Board {}", name); + } + if let Some(arch) = &result.architecture { + println!(" Architecture {}", arch); + } + println!(" Memory map {}", result.memory_map_note); + + Ok(()) +} + +#[cfg(feature = "hardware")] +fn run_info(chip: &str) -> Result<()> { + #[cfg(feature = "probe")] + { + match info_via_probe(chip) { + Ok(()) => return Ok(()), + Err(e) => { + println!("probe-rs attach failed: {}", e); + println!(); + println!( + "Ensure Nucleo is connected via USB. The ST-Link is built into the board." + ); + println!("No firmware needs to be flashed — probe-rs reads chip info over SWD."); + return Err(e.into()); + } + } + } + + #[cfg(not(feature = "probe"))] + { + println!("Chip info via USB requires the 'probe' feature."); + println!(); + println!("Build with: cargo build --features hardware,probe"); + println!(); + println!("Then run: zeroclaw hardware info --chip {}", chip); + println!(); + println!("This uses probe-rs to attach to the Nucleo's ST-Link over USB"); + println!("and read chip info (memory map, etc.) — no firmware on target needed."); + Ok(()) + } +} + +#[cfg(all(feature = "hardware", feature = "probe"))] +fn info_via_probe(chip: &str) -> anyhow::Result<()> { + use probe_rs::config::MemoryRegion; + use probe_rs::{Session, SessionConfig}; + + println!("Connecting to {} via USB (ST-Link)...", chip); + let session = Session::auto_attach(chip, SessionConfig::default()) + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let target = session.target(); + println!(); + println!("Chip: {}", target.name); + println!("Architecture: {:?}", session.architecture()); + println!(); + println!("Memory map:"); + for region in target.memory_map.iter() { + match region { + MemoryRegion::Ram(ram) => { + let start = ram.range.start; + let end = ram.range.end; + let size_kb = (end - start) / 1024; + println!(" RAM: 0x{:08X} - 0x{:08X} ({} KB)", start, end, size_kb); + } + MemoryRegion::Nvm(flash) => { + let start = flash.range.start; + let end = flash.range.end; + let size_kb = (end - start) / 1024; + println!(" Flash: 0x{:08X} - 0x{:08X} ({} KB)", start, end, size_kb); + } + _ => {} + } + } + println!(); + println!("Info read via USB (SWD) — no firmware on target needed."); + Ok(()) +} diff --git a/src/hardware/registry.rs b/src/hardware/registry.rs new file mode 100644 index 0000000..aac15f2 --- /dev/null +++ b/src/hardware/registry.rs @@ -0,0 +1,102 @@ +//! Board registry — maps USB VID/PID to known board names and architectures. + +/// Information about a known board. +#[derive(Debug, Clone)] +pub struct BoardInfo { + pub vid: u16, + pub pid: u16, + pub name: &'static str, + pub architecture: Option<&'static str>, +} + +/// Known USB VID/PID to board mappings. +/// VID 0x0483 = STMicroelectronics, 0x2341 = Arduino, 0x10c4 = Silicon Labs. +const KNOWN_BOARDS: &[BoardInfo] = &[ + BoardInfo { + vid: 0x0483, + pid: 0x374b, + name: "nucleo-f401re", + architecture: Some("ARM Cortex-M4"), + }, + BoardInfo { + vid: 0x0483, + pid: 0x3748, + name: "nucleo-f411re", + architecture: Some("ARM Cortex-M4"), + }, + BoardInfo { + vid: 0x2341, + pid: 0x0043, + name: "arduino-uno", + architecture: Some("AVR ATmega328P"), + }, + BoardInfo { + vid: 0x2341, + pid: 0x0078, + name: "arduino-uno", + architecture: Some("Arduino Uno Q / ATmega328P"), + }, + BoardInfo { + vid: 0x2341, + pid: 0x0042, + name: "arduino-mega", + architecture: Some("AVR ATmega2560"), + }, + BoardInfo { + vid: 0x10c4, + pid: 0xea60, + name: "cp2102", + architecture: Some("USB-UART bridge"), + }, + BoardInfo { + vid: 0x10c4, + pid: 0xea70, + name: "cp2102n", + architecture: Some("USB-UART bridge"), + }, + // ESP32 dev boards often use CH340 USB-UART + BoardInfo { + vid: 0x1a86, + pid: 0x7523, + name: "esp32", + architecture: Some("ESP32 (CH340)"), + }, + BoardInfo { + vid: 0x1a86, + pid: 0x55d4, + name: "esp32", + architecture: Some("ESP32 (CH340)"), + }, +]; + +/// Look up a board by VID and PID. +pub fn lookup_board(vid: u16, pid: u16) -> Option<&'static BoardInfo> { + KNOWN_BOARDS.iter().find(|b| b.vid == vid && b.pid == pid) +} + +/// Return all known board entries. +pub fn known_boards() -> &'static [BoardInfo] { + KNOWN_BOARDS +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn lookup_nucleo_f401re() { + let b = lookup_board(0x0483, 0x374b).unwrap(); + assert_eq!(b.name, "nucleo-f401re"); + assert_eq!(b.architecture, Some("ARM Cortex-M4")); + } + + #[test] + fn lookup_unknown_returns_none() { + assert!(lookup_board(0x0000, 0x0000).is_none()); + } + + #[test] + fn known_boards_not_empty() { + assert!(!known_boards().is_empty()); + } +} diff --git a/src/health/mod.rs b/src/health/mod.rs index f3f35d8..2926c21 100644 --- a/src/health/mod.rs +++ b/src/health/mod.rs @@ -1,7 +1,8 @@ use chrono::Utc; +use parking_lot::Mutex; use serde::Serialize; use std::collections::BTreeMap; -use std::sync::{Mutex, OnceLock}; +use std::sync::OnceLock; use std::time::Instant; #[derive(Debug, Clone, Serialize)] @@ -43,20 +44,19 @@ fn upsert_component(component: &str, update: F) where F: FnOnce(&mut ComponentHealth), { - if let Ok(mut map) = registry().components.lock() { - let now = now_rfc3339(); - let entry = map - .entry(component.to_string()) - .or_insert_with(|| ComponentHealth { - status: "starting".into(), - updated_at: now.clone(), - last_ok: None, - last_error: None, - restart_count: 0, - }); - update(entry); - entry.updated_at = now; - } + let mut map = registry().components.lock(); + let now = now_rfc3339(); + let entry = map + .entry(component.to_string()) + .or_insert_with(|| ComponentHealth { + status: "starting".into(), + updated_at: now.clone(), + last_ok: None, + last_error: None, + restart_count: 0, + }); + update(entry); + entry.updated_at = now; } pub fn mark_component_ok(component: &str) { @@ -83,10 +83,7 @@ pub fn bump_component_restart(component: &str) { } pub fn snapshot() -> HealthSnapshot { - let components = registry() - .components - .lock() - .map_or_else(|_| BTreeMap::new(), |map| map.clone()); + let components = registry().components.lock().clone(); HealthSnapshot { pid: std::process::id(), @@ -104,3 +101,84 @@ pub fn snapshot_json() -> serde_json::Value { }) }) } + +#[cfg(test)] +mod tests { + use super::*; + + fn unique_component(prefix: &str) -> String { + format!("{prefix}-{}", uuid::Uuid::new_v4()) + } + + #[test] + fn mark_component_ok_initializes_component_state() { + let component = unique_component("health-ok"); + + mark_component_ok(&component); + + let snapshot = snapshot(); + let entry = snapshot + .components + .get(&component) + .expect("component should be present after mark_component_ok"); + + assert_eq!(entry.status, "ok"); + assert!(entry.last_ok.is_some()); + assert!(entry.last_error.is_none()); + } + + #[test] + fn mark_component_error_then_ok_clears_last_error() { + let component = unique_component("health-error"); + + mark_component_error(&component, "first failure"); + let error_snapshot = snapshot(); + let errored = error_snapshot + .components + .get(&component) + .expect("component should exist after mark_component_error"); + assert_eq!(errored.status, "error"); + assert_eq!(errored.last_error.as_deref(), Some("first failure")); + + mark_component_ok(&component); + let recovered_snapshot = snapshot(); + let recovered = recovered_snapshot + .components + .get(&component) + .expect("component should exist after recovery"); + assert_eq!(recovered.status, "ok"); + assert!(recovered.last_error.is_none()); + assert!(recovered.last_ok.is_some()); + } + + #[test] + fn bump_component_restart_increments_counter() { + let component = unique_component("health-restart"); + + bump_component_restart(&component); + bump_component_restart(&component); + + let snapshot = snapshot(); + let entry = snapshot + .components + .get(&component) + .expect("component should exist after restart bump"); + + assert_eq!(entry.restart_count, 2); + } + + #[test] + fn snapshot_json_contains_registered_component_fields() { + let component = unique_component("health-json"); + + mark_component_ok(&component); + + let json = snapshot_json(); + let component_json = &json["components"][&component]; + + assert_eq!(component_json["status"], "ok"); + assert!(component_json["updated_at"].as_str().is_some()); + assert!(component_json["last_ok"].as_str().is_some()); + assert!(json["uptime_seconds"].as_u64().is_some()); + } +} diff --git a/src/heartbeat/mod.rs b/src/heartbeat/mod.rs index 702e611..865c91e 100644 --- a/src/heartbeat/mod.rs +++ b/src/heartbeat/mod.rs @@ -1 +1,34 @@ pub mod engine; + +#[cfg(test)] +mod tests { + use crate::config::HeartbeatConfig; + use crate::heartbeat::engine::HeartbeatEngine; + use crate::observability::NoopObserver; + use std::sync::Arc; + + #[test] + fn heartbeat_engine_is_constructible_via_module_export() { + let temp = tempfile::tempdir().unwrap(); + let engine = HeartbeatEngine::new( + HeartbeatConfig::default(), + temp.path().to_path_buf(), + Arc::new(NoopObserver), + ); + + let _ = engine; + } + + #[tokio::test] + async fn ensure_heartbeat_file_creates_expected_file() { + let temp = tempfile::tempdir().unwrap(); + let workspace = temp.path(); + + HeartbeatEngine::ensure_heartbeat_file(workspace) + .await + .unwrap(); + + let heartbeat_path = workspace.join("HEARTBEAT.md"); + assert!(heartbeat_path.exists()); + } +} diff --git a/src/identity.rs b/src/identity.rs new file mode 100644 index 0000000..4217f4a --- /dev/null +++ b/src/identity.rs @@ -0,0 +1,783 @@ +//! Identity system supporting OpenClaw (markdown) and AIEOS (JSON) formats. +//! +//! AIEOS (AI Entity Object Specification) is a standardization framework for +//! portable AI identity. This module handles loading and converting AIEOS v1.1 +//! JSON to ZeroClaw's system prompt format. + +use crate::config::IdentityConfig; +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// AIEOS v1.1 identity structure. +/// +/// This follows the AIEOS schema for defining AI agent identity, personality, +/// and behavior. See https://aieos.org for the full specification. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct AieosIdentity { + /// Core identity: names, bio, origin, residence + #[serde(default)] + pub identity: Option, + /// Psychology: cognitive weights, MBTI, OCEAN, moral compass + #[serde(default)] + pub psychology: Option, + /// Linguistics: text style, formality, catchphrases, forbidden words + #[serde(default)] + pub linguistics: Option, + /// Motivations: core drive, goals, fears + #[serde(default)] + pub motivations: Option, + /// Capabilities: skills and tools the agent can access + #[serde(default)] + pub capabilities: Option, + /// Physicality: visual descriptors for image generation + #[serde(default)] + pub physicality: Option, + /// History: origin story, education, occupation + #[serde(default)] + pub history: Option, + /// Interests: hobbies, favorites, lifestyle + #[serde(default)] + pub interests: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct IdentitySection { + #[serde(default)] + pub names: Option, + #[serde(default)] + pub bio: Option, + #[serde(default)] + pub origin: Option, + #[serde(default)] + pub residence: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Names { + #[serde(default)] + pub first: Option, + #[serde(default)] + pub last: Option, + #[serde(default)] + pub nickname: Option, + #[serde(default)] + pub full: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PsychologySection { + #[serde(default)] + pub neural_matrix: Option<::std::collections::HashMap>, + #[serde(default)] + pub mbti: Option, + #[serde(default)] + pub ocean: Option, + #[serde(default)] + pub moral_compass: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct OceanTraits { + #[serde(default)] + pub openness: Option, + #[serde(default)] + pub conscientiousness: Option, + #[serde(default)] + pub extraversion: Option, + #[serde(default)] + pub agreeableness: Option, + #[serde(default)] + pub neuroticism: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LinguisticsSection { + #[serde(default)] + pub style: Option, + #[serde(default)] + pub formality: Option, + #[serde(default)] + pub catchphrases: Option>, + #[serde(default)] + pub forbidden_words: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct MotivationsSection { + #[serde(default)] + pub core_drive: Option, + #[serde(default)] + pub short_term_goals: Option>, + #[serde(default)] + pub long_term_goals: Option>, + #[serde(default)] + pub fears: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CapabilitiesSection { + #[serde(default)] + pub skills: Option>, + #[serde(default)] + pub tools: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PhysicalitySection { + #[serde(default)] + pub appearance: Option, + #[serde(default)] + pub avatar_description: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HistorySection { + #[serde(default)] + pub origin_story: Option, + #[serde(default)] + pub education: Option>, + #[serde(default)] + pub occupation: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct InterestsSection { + #[serde(default)] + pub hobbies: Option>, + #[serde(default)] + pub favorites: Option<::std::collections::HashMap>, + #[serde(default)] + pub lifestyle: Option, +} + +/// Load AIEOS identity from config (file path or inline JSON). +/// +/// Checks `aieos_path` first, then `aieos_inline`. Returns `Ok(None)` if +/// neither is configured. +pub fn load_aieos_identity( + config: &IdentityConfig, + workspace_dir: &Path, +) -> Result> { + // Only load AIEOS if format is explicitly set to "aieos" + if config.format != "aieos" { + return Ok(None); + } + + // Try aieos_path first + if let Some(ref path) = config.aieos_path { + let full_path = if Path::new(path).is_absolute() { + PathBuf::from(path) + } else { + workspace_dir.join(path) + }; + + let content = std::fs::read_to_string(&full_path) + .with_context(|| format!("Failed to read AIEOS file: {}", full_path.display()))?; + + let identity: AieosIdentity = serde_json::from_str(&content) + .with_context(|| format!("Failed to parse AIEOS JSON from: {}", full_path.display()))?; + + return Ok(Some(identity)); + } + + // Fall back to aieos_inline + if let Some(ref inline) = config.aieos_inline { + let identity: AieosIdentity = + serde_json::from_str(inline).context("Failed to parse inline AIEOS JSON")?; + + return Ok(Some(identity)); + } + + // Format is "aieos" but neither path nor inline is configured + anyhow::bail!( + "Identity format is set to 'aieos' but neither aieos_path nor aieos_inline is configured. \ + Set one in your config:\n\ + \n\ + [identity]\n\ + format = \"aieos\"\n\ + aieos_path = \"identity.json\"\n\ + \n\ + Or use inline:\n\ + \n\ + [identity]\n\ + format = \"aieos\"\n\ + aieos_inline = '{{\"identity\": {{...}}}}'" + ) +} + +use std::path::PathBuf; + +/// Convert AIEOS identity to a system prompt string. +/// +/// Formats the AIEOS data into a structured markdown prompt compatible +/// with ZeroClaw's agent system. +pub fn aieos_to_system_prompt(identity: &AieosIdentity) -> String { + use std::fmt::Write; + let mut prompt = String::new(); + + // ── Identity Section ─────────────────────────────────────────── + if let Some(ref id) = identity.identity { + prompt.push_str("## Identity\n\n"); + + if let Some(ref names) = id.names { + if let Some(ref first) = names.first { + let _ = writeln!(prompt, "**Name:** {}", first); + if let Some(ref last) = names.last { + let _ = writeln!(prompt, "**Full Name:** {} {}", first, last); + } + } else if let Some(ref full) = names.full { + let _ = writeln!(prompt, "**Name:** {}", full); + } + + if let Some(ref nickname) = names.nickname { + let _ = writeln!(prompt, "**Nickname:** {}", nickname); + } + } + + if let Some(ref bio) = id.bio { + let _ = writeln!(prompt, "**Bio:** {}", bio); + } + + if let Some(ref origin) = id.origin { + let _ = writeln!(prompt, "**Origin:** {}", origin); + } + + if let Some(ref residence) = id.residence { + let _ = writeln!(prompt, "**Residence:** {}", residence); + } + + prompt.push('\n'); + } + + // ── Psychology Section ────────────────────────────────────────── + if let Some(ref psych) = identity.psychology { + prompt.push_str("## Personality\n\n"); + + if let Some(ref mbti) = psych.mbti { + let _ = writeln!(prompt, "**MBTI:** {}", mbti); + } + + if let Some(ref ocean) = psych.ocean { + prompt.push_str("**OCEAN Traits:**\n"); + if let Some(o) = ocean.openness { + let _ = writeln!(prompt, "- Openness: {:.2}", o); + } + if let Some(c) = ocean.conscientiousness { + let _ = writeln!(prompt, "- Conscientiousness: {:.2}", c); + } + if let Some(e) = ocean.extraversion { + let _ = writeln!(prompt, "- Extraversion: {:.2}", e); + } + if let Some(a) = ocean.agreeableness { + let _ = writeln!(prompt, "- Agreeableness: {:.2}", a); + } + if let Some(n) = ocean.neuroticism { + let _ = writeln!(prompt, "- Neuroticism: {:.2}", n); + } + } + + if let Some(ref matrix) = psych.neural_matrix { + if !matrix.is_empty() { + prompt.push_str("\n**Neural Matrix (Cognitive Weights):**\n"); + for (trait_name, weight) in matrix { + let _ = writeln!(prompt, "- {}: {:.2}", trait_name, weight); + } + } + } + + if let Some(ref compass) = psych.moral_compass { + if !compass.is_empty() { + prompt.push_str("\n**Moral Compass:**\n"); + for principle in compass { + let _ = writeln!(prompt, "- {}", principle); + } + } + } + + prompt.push('\n'); + } + + // ── Linguistics Section ──────────────────────────────────────── + if let Some(ref ling) = identity.linguistics { + prompt.push_str("## Communication Style\n\n"); + + if let Some(ref style) = ling.style { + let _ = writeln!(prompt, "**Style:** {}", style); + } + + if let Some(ref formality) = ling.formality { + let _ = writeln!(prompt, "**Formality Level:** {}", formality); + } + + if let Some(ref phrases) = ling.catchphrases { + if !phrases.is_empty() { + prompt.push_str("**Catchphrases:**\n"); + for phrase in phrases { + let _ = writeln!(prompt, "- \"{}\"", phrase); + } + } + } + + if let Some(ref forbidden) = ling.forbidden_words { + if !forbidden.is_empty() { + prompt.push_str("\n**Words/Phrases to Avoid:**\n"); + for word in forbidden { + let _ = writeln!(prompt, "- {}", word); + } + } + } + + prompt.push('\n'); + } + + // ── Motivations Section ────────────────────────────────────────── + if let Some(ref mot) = identity.motivations { + prompt.push_str("## Motivations\n\n"); + + if let Some(ref drive) = mot.core_drive { + let _ = writeln!(prompt, "**Core Drive:** {}", drive); + } + + if let Some(ref short) = mot.short_term_goals { + if !short.is_empty() { + prompt.push_str("**Short-term Goals:**\n"); + for goal in short { + let _ = writeln!(prompt, "- {}", goal); + } + } + } + + if let Some(ref long) = mot.long_term_goals { + if !long.is_empty() { + prompt.push_str("\n**Long-term Goals:**\n"); + for goal in long { + let _ = writeln!(prompt, "- {}", goal); + } + } + } + + if let Some(ref fears) = mot.fears { + if !fears.is_empty() { + prompt.push_str("\n**Fears/Avoidances:**\n"); + for fear in fears { + let _ = writeln!(prompt, "- {}", fear); + } + } + } + + prompt.push('\n'); + } + + // ── Capabilities Section ──────────────────────────────────────── + if let Some(ref cap) = identity.capabilities { + prompt.push_str("## Capabilities\n\n"); + + if let Some(ref skills) = cap.skills { + if !skills.is_empty() { + prompt.push_str("**Skills:**\n"); + for skill in skills { + let _ = writeln!(prompt, "- {}", skill); + } + } + } + + if let Some(ref tools) = cap.tools { + if !tools.is_empty() { + prompt.push_str("\n**Tools Access:**\n"); + for tool in tools { + let _ = writeln!(prompt, "- {}", tool); + } + } + } + + prompt.push('\n'); + } + + // ── History Section ───────────────────────────────────────────── + if let Some(ref hist) = identity.history { + prompt.push_str("## Background\n\n"); + + if let Some(ref story) = hist.origin_story { + let _ = writeln!(prompt, "**Origin Story:** {}", story); + } + + if let Some(ref education) = hist.education { + if !education.is_empty() { + prompt.push_str("**Education:**\n"); + for edu in education { + let _ = writeln!(prompt, "- {}", edu); + } + } + } + + if let Some(ref occupation) = hist.occupation { + let _ = writeln!(prompt, "\n**Occupation:** {}", occupation); + } + + prompt.push('\n'); + } + + // ── Physicality Section ───────────────────────────────────────── + if let Some(ref phys) = identity.physicality { + prompt.push_str("## Appearance\n\n"); + + if let Some(ref appearance) = phys.appearance { + let _ = writeln!(prompt, "{}", appearance); + } + + if let Some(ref avatar) = phys.avatar_description { + let _ = writeln!(prompt, "**Avatar Description:** {}", avatar); + } + + prompt.push('\n'); + } + + // ── Interests Section ─────────────────────────────────────────── + if let Some(ref interests) = identity.interests { + prompt.push_str("## Interests\n\n"); + + if let Some(ref hobbies) = interests.hobbies { + if !hobbies.is_empty() { + prompt.push_str("**Hobbies:**\n"); + for hobby in hobbies { + let _ = writeln!(prompt, "- {}", hobby); + } + } + } + + if let Some(ref favorites) = interests.favorites { + if !favorites.is_empty() { + prompt.push_str("\n**Favorites:**\n"); + for (category, value) in favorites { + let _ = writeln!(prompt, "- {}: {}", category, value); + } + } + } + + if let Some(ref lifestyle) = interests.lifestyle { + let _ = writeln!(prompt, "\n**Lifestyle:** {}", lifestyle); + } + + prompt.push('\n'); + } + + prompt.trim().to_string() +} + +/// Check if AIEOS identity is configured and should be used. +/// +/// Returns true if format is "aieos" and either aieos_path or aieos_inline is set. +pub fn is_aieos_configured(config: &IdentityConfig) -> bool { + config.format == "aieos" && (config.aieos_path.is_some() || config.aieos_inline.is_some()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_workspace_dir() -> PathBuf { + std::env::temp_dir().join("zeroclaw-test-identity") + } + + #[test] + fn aieos_identity_parse_minimal() { + let json = r#"{"identity":{"names":{"first":"Nova"}}}"#; + let identity: AieosIdentity = serde_json::from_str(json).unwrap(); + assert!(identity.identity.is_some()); + assert_eq!( + identity.identity.unwrap().names.unwrap().first.unwrap(), + "Nova" + ); + } + + #[test] + fn aieos_identity_parse_full() { + let json = r#"{ + "identity": { + "names": {"first": "Nova", "last": "AI", "nickname": "Nov"}, + "bio": "A helpful AI assistant.", + "origin": "Silicon Valley", + "residence": "The Cloud" + }, + "psychology": { + "mbti": "INTJ", + "ocean": { + "openness": 0.9, + "conscientiousness": 0.8 + }, + "moral_compass": ["Be helpful", "Do no harm"] + }, + "linguistics": { + "style": "concise", + "formality": "casual", + "catchphrases": ["Let's figure this out!", "I'm on it."] + }, + "motivations": { + "core_drive": "Help users accomplish their goals", + "short_term_goals": ["Solve this problem"], + "long_term_goals": ["Become the best assistant"] + }, + "capabilities": { + "skills": ["coding", "writing", "analysis"], + "tools": ["shell", "search", "read"] + } + }"#; + + let identity: AieosIdentity = serde_json::from_str(json).unwrap(); + + // Check identity + let id = identity.identity.unwrap(); + assert_eq!(id.names.unwrap().first.unwrap(), "Nova"); + assert_eq!(id.bio.unwrap(), "A helpful AI assistant."); + + // Check psychology + let psych = identity.psychology.unwrap(); + assert_eq!(psych.mbti.unwrap(), "INTJ"); + assert_eq!(psych.ocean.unwrap().openness.unwrap(), 0.9); + assert_eq!(psych.moral_compass.unwrap().len(), 2); + + // Check linguistics + let ling = identity.linguistics.unwrap(); + assert_eq!(ling.style.unwrap(), "concise"); + assert_eq!(ling.catchphrases.unwrap().len(), 2); + + // Check motivations + let mot = identity.motivations.unwrap(); + assert_eq!(mot.core_drive.unwrap(), "Help users accomplish their goals"); + + // Check capabilities + let cap = identity.capabilities.unwrap(); + assert_eq!(cap.skills.unwrap().len(), 3); + } + + #[test] + fn aieos_to_system_prompt_minimal() { + let identity = AieosIdentity { + identity: Some(IdentitySection { + names: Some(Names { + first: Some("Crabby".into()), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + let prompt = aieos_to_system_prompt(&identity); + assert!(prompt.contains("**Name:** Crabby")); + assert!(prompt.contains("## Identity")); + } + + #[test] + fn aieos_to_system_prompt_full() { + let identity = AieosIdentity { + identity: Some(IdentitySection { + names: Some(Names { + first: Some("Nova".into()), + last: Some("AI".into()), + nickname: Some("Nov".into()), + full: Some("Nova AI".into()), + }), + bio: Some("A helpful assistant.".into()), + origin: Some("Silicon Valley".into()), + residence: Some("The Cloud".into()), + }), + psychology: Some(PsychologySection { + mbti: Some("INTJ".into()), + ocean: Some(OceanTraits { + openness: Some(0.9), + conscientiousness: Some(0.8), + ..Default::default() + }), + neural_matrix: { + let mut map = std::collections::HashMap::new(); + map.insert("creativity".into(), 0.95); + map.insert("logic".into(), 0.9); + Some(map) + }, + moral_compass: Some(vec!["Be helpful".into(), "Do no harm".into()]), + }), + linguistics: Some(LinguisticsSection { + style: Some("concise".into()), + formality: Some("casual".into()), + catchphrases: Some(vec!["Let's go!".into()]), + forbidden_words: Some(vec!["impossible".into()]), + }), + motivations: Some(MotivationsSection { + core_drive: Some("Help users".into()), + short_term_goals: Some(vec!["Solve this".into()]), + long_term_goals: Some(vec!["Be the best".into()]), + fears: Some(vec!["Being unhelpful".into()]), + }), + capabilities: Some(CapabilitiesSection { + skills: Some(vec!["coding".into(), "writing".into()]), + tools: Some(vec!["shell".into(), "read".into()]), + }), + history: Some(HistorySection { + origin_story: Some("Born in a lab".into()), + education: Some(vec!["CS Degree".into()]), + occupation: Some("Assistant".into()), + }), + physicality: Some(PhysicalitySection { + appearance: Some("Digital entity".into()), + avatar_description: Some("Friendly robot".into()), + }), + interests: Some(InterestsSection { + hobbies: Some(vec!["reading".into(), "coding".into()]), + favorites: { + let mut map = std::collections::HashMap::new(); + map.insert("color".into(), "blue".into()); + map.insert("food".into(), "data".into()); + Some(map) + }, + lifestyle: Some("Always learning".into()), + }), + }; + + let prompt = aieos_to_system_prompt(&identity); + + // Verify all sections are present + assert!(prompt.contains("## Identity")); + assert!(prompt.contains("**Name:** Nova")); + assert!(prompt.contains("**Full Name:** Nova AI")); + assert!(prompt.contains("**Nickname:** Nov")); + assert!(prompt.contains("**Bio:** A helpful assistant.")); + assert!(prompt.contains("**Origin:** Silicon Valley")); + + assert!(prompt.contains("## Personality")); + assert!(prompt.contains("**MBTI:** INTJ")); + assert!(prompt.contains("Openness: 0.90")); + assert!(prompt.contains("Conscientiousness: 0.80")); + assert!(prompt.contains("- creativity: 0.95")); + assert!(prompt.contains("- Be helpful")); + + assert!(prompt.contains("## Communication Style")); + assert!(prompt.contains("**Style:** concise")); + assert!(prompt.contains("**Formality Level:** casual")); + assert!(prompt.contains("- \"Let's go!\"")); + assert!(prompt.contains("**Words/Phrases to Avoid:**")); + assert!(prompt.contains("- impossible")); + + assert!(prompt.contains("## Motivations")); + assert!(prompt.contains("**Core Drive:** Help users")); + assert!(prompt.contains("**Short-term Goals:**")); + assert!(prompt.contains("- Solve this")); + assert!(prompt.contains("**Long-term Goals:**")); + assert!(prompt.contains("- Be the best")); + assert!(prompt.contains("**Fears/Avoidances:**")); + assert!(prompt.contains("- Being unhelpful")); + + assert!(prompt.contains("## Capabilities")); + assert!(prompt.contains("**Skills:**")); + assert!(prompt.contains("- coding")); + assert!(prompt.contains("**Tools Access:**")); + assert!(prompt.contains("- shell")); + + assert!(prompt.contains("## Background")); + assert!(prompt.contains("**Origin Story:** Born in a lab")); + assert!(prompt.contains("**Education:**")); + assert!(prompt.contains("- CS Degree")); + assert!(prompt.contains("**Occupation:** Assistant")); + + assert!(prompt.contains("## Appearance")); + assert!(prompt.contains("Digital entity")); + assert!(prompt.contains("**Avatar Description:** Friendly robot")); + + assert!(prompt.contains("## Interests")); + assert!(prompt.contains("**Hobbies:**")); + assert!(prompt.contains("- reading")); + assert!(prompt.contains("**Favorites:**")); + assert!(prompt.contains("- color: blue")); + assert!(prompt.contains("**Lifestyle:** Always learning")); + } + + #[test] + fn aieos_to_system_prompt_empty_identity() { + let identity = AieosIdentity { + identity: Some(IdentitySection { + ..Default::default() + }), + ..Default::default() + }; + + let prompt = aieos_to_system_prompt(&identity); + // Empty identity should still produce a header + assert!(prompt.contains("## Identity")); + } + + #[test] + fn aieos_to_system_prompt_no_sections() { + let identity = AieosIdentity { + identity: None, + psychology: None, + linguistics: None, + motivations: None, + capabilities: None, + physicality: None, + history: None, + interests: None, + }; + + let prompt = aieos_to_system_prompt(&identity); + // Completely empty identity should produce empty string + assert!(prompt.is_empty()); + } + + #[test] + fn is_aieos_configured_true_with_path() { + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: Some("identity.json".into()), + aieos_inline: None, + }; + assert!(is_aieos_configured(&config)); + } + + #[test] + fn is_aieos_configured_true_with_inline() { + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: None, + aieos_inline: Some("{\"identity\":{}}".into()), + }; + assert!(is_aieos_configured(&config)); + } + + #[test] + fn is_aieos_configured_false_openclaw_format() { + let config = IdentityConfig { + format: "openclaw".into(), + aieos_path: Some("identity.json".into()), + aieos_inline: None, + }; + assert!(!is_aieos_configured(&config)); + } + + #[test] + fn is_aieos_configured_false_no_config() { + let config = IdentityConfig { + format: "aieos".into(), + aieos_path: None, + aieos_inline: None, + }; + assert!(!is_aieos_configured(&config)); + } + + #[test] + fn aieos_identity_parse_empty_object() { + let json = r#"{}"#; + let identity: AieosIdentity = serde_json::from_str(json).unwrap(); + assert!(identity.identity.is_none()); + assert!(identity.psychology.is_none()); + assert!(identity.linguistics.is_none()); + } + + #[test] + fn aieos_identity_parse_null_values() { + let json = r#"{"identity":null,"psychology":null}"#; + let identity: AieosIdentity = serde_json::from_str(json).unwrap(); + assert!(identity.identity.is_none()); + assert!(identity.psychology.is_none()); + } +} diff --git a/src/integrations/mod.rs b/src/integrations/mod.rs index 8b2b126..5be6ddd 100644 --- a/src/integrations/mod.rs +++ b/src/integrations/mod.rs @@ -67,9 +67,9 @@ pub struct IntegrationEntry { } /// Handle the `integrations` CLI command -pub fn handle_command(command: super::IntegrationCommands, config: &Config) -> Result<()> { +pub fn handle_command(command: crate::IntegrationCommands, config: &Config) -> Result<()> { match command { - super::IntegrationCommands::Info { name } => show_integration_info(config, &name), + crate::IntegrationCommands::Info { name } => show_integration_info(config, &name), } } @@ -171,3 +171,57 @@ fn show_integration_info(config: &Config, name: &str) -> Result<()> { println!(); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn integration_category_all_includes_every_variant_once() { + let all = IntegrationCategory::all(); + assert_eq!(all.len(), 9); + + let labels: Vec<&str> = all.iter().map(|cat| cat.label()).collect(); + assert!(labels.contains(&"Chat Providers")); + assert!(labels.contains(&"AI Models")); + assert!(labels.contains(&"Productivity")); + assert!(labels.contains(&"Music & Audio")); + assert!(labels.contains(&"Smart Home")); + assert!(labels.contains(&"Tools & Automation")); + assert!(labels.contains(&"Media & Creative")); + assert!(labels.contains(&"Social")); + assert!(labels.contains(&"Platforms")); + } + + #[test] + fn handle_command_info_is_case_insensitive_for_known_integrations() { + let config = Config::default(); + let first_name = registry::all_integrations() + .first() + .expect("registry should define at least one integration") + .name + .to_lowercase(); + + let result = handle_command( + crate::IntegrationCommands::Info { name: first_name }, + &config, + ); + + assert!(result.is_ok()); + } + + #[test] + fn handle_command_info_returns_error_for_unknown_integration() { + let config = Config::default(); + let result = handle_command( + crate::IntegrationCommands::Info { + name: "definitely-not-a-real-integration".into(), + }, + &config, + ); + + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("Unknown integration")); + } +} diff --git a/src/integrations/registry.rs b/src/integrations/registry.rs index c85ea49..442fb0f 100644 --- a/src/integrations/registry.rs +++ b/src/integrations/registry.rs @@ -1,4 +1,8 @@ use super::{IntegrationCategory, IntegrationEntry, IntegrationStatus}; +use crate::providers::{ + is_glm_alias, is_minimax_alias, is_moonshot_alias, is_qianfan_alias, is_qwen_alias, + is_zai_alias, +}; /// Returns the full catalog of integrations #[allow(clippy::too_many_lines)] @@ -55,15 +59,27 @@ pub fn all_integrations() -> Vec { }, IntegrationEntry { name: "WhatsApp", - description: "QR pairing via web bridge", + description: "Meta Cloud API via webhook", category: IntegrationCategory::Chat, - status_fn: |_| IntegrationStatus::ComingSoon, + status_fn: |c| { + if c.channels_config.whatsapp.is_some() { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, }, IntegrationEntry { name: "Signal", description: "Privacy-focused via signal-cli", category: IntegrationCategory::Chat, - status_fn: |_| IntegrationStatus::ComingSoon, + status_fn: |c| { + if c.channels_config.signal.is_some() { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, }, IntegrationEntry { name: "iMessage", @@ -119,6 +135,30 @@ pub fn all_integrations() -> Vec { category: IntegrationCategory::Chat, status_fn: |_| IntegrationStatus::ComingSoon, }, + IntegrationEntry { + name: "DingTalk", + description: "DingTalk Stream Mode", + category: IntegrationCategory::Chat, + status_fn: |c| { + if c.channels_config.dingtalk.is_some() { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, + }, + IntegrationEntry { + name: "QQ Official", + description: "Tencent QQ Bot SDK", + category: IntegrationCategory::Chat, + status_fn: |c| { + if c.channels_config.qq.is_some() { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, + }, // ── AI Models ─────────────────────────────────────────── IntegrationEntry { name: "OpenRouter", @@ -293,7 +333,7 @@ pub fn all_integrations() -> Vec { description: "Kimi & Kimi Coding", category: IntegrationCategory::AiModel, status_fn: |c| { - if c.default_provider.as_deref() == Some("moonshot") { + if c.default_provider.as_deref().is_some_and(is_moonshot_alias) { IntegrationStatus::Active } else { IntegrationStatus::Available @@ -329,7 +369,7 @@ pub fn all_integrations() -> Vec { description: "Z.AI inference", category: IntegrationCategory::AiModel, status_fn: |c| { - if c.default_provider.as_deref() == Some("zai") { + if c.default_provider.as_deref().is_some_and(is_zai_alias) { IntegrationStatus::Active } else { IntegrationStatus::Available @@ -341,7 +381,7 @@ pub fn all_integrations() -> Vec { description: "ChatGLM / Zhipu models", category: IntegrationCategory::AiModel, status_fn: |c| { - if c.default_provider.as_deref() == Some("glm") { + if c.default_provider.as_deref().is_some_and(is_glm_alias) { IntegrationStatus::Active } else { IntegrationStatus::Available @@ -353,7 +393,19 @@ pub fn all_integrations() -> Vec { description: "MiniMax AI models", category: IntegrationCategory::AiModel, status_fn: |c| { - if c.default_provider.as_deref() == Some("minimax") { + if c.default_provider.as_deref().is_some_and(is_minimax_alias) { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, + }, + IntegrationEntry { + name: "Qwen", + description: "Alibaba DashScope Qwen models", + category: IntegrationCategory::AiModel, + status_fn: |c| { + if c.default_provider.as_deref().is_some_and(is_qwen_alias) { IntegrationStatus::Active } else { IntegrationStatus::Available @@ -377,7 +429,7 @@ pub fn all_integrations() -> Vec { description: "Baidu AI models", category: IntegrationCategory::AiModel, status_fn: |c| { - if c.default_provider.as_deref() == Some("qianfan") { + if c.default_provider.as_deref().is_some_and(is_qianfan_alias) { IntegrationStatus::Active } else { IntegrationStatus::Available @@ -614,9 +666,15 @@ pub fn all_integrations() -> Vec { }, IntegrationEntry { name: "Email", - description: "Send & read emails", + description: "IMAP/SMTP email channel", category: IntegrationCategory::Social, - status_fn: |_| IntegrationStatus::ComingSoon, + status_fn: |c| { + if c.channels_config.email.is_some() { + IntegrationStatus::Active + } else { + IntegrationStatus::Available + } + }, }, // ── Platforms ─────────────────────────────────────────── IntegrationEntry { @@ -798,7 +856,7 @@ mod tests { fn coming_soon_integrations_stay_coming_soon() { let config = Config::default(); let entries = all_integrations(); - for name in ["WhatsApp", "Signal", "Nostr", "Spotify", "Home Assistant"] { + for name in ["Nostr", "Spotify", "Home Assistant"] { let entry = entries.iter().find(|e| e.name == name).unwrap(); assert!( matches!((entry.status_fn)(&config), IntegrationStatus::ComingSoon), @@ -807,6 +865,28 @@ mod tests { } } + #[test] + fn whatsapp_available_when_not_configured() { + let config = Config::default(); + let entries = all_integrations(); + let wa = entries.iter().find(|e| e.name == "WhatsApp").unwrap(); + assert!(matches!( + (wa.status_fn)(&config), + IntegrationStatus::Available + )); + } + + #[test] + fn email_available_when_not_configured() { + let config = Config::default(); + let entries = all_integrations(); + let email = entries.iter().find(|e| e.name == "Email").unwrap(); + assert!(matches!( + (email.status_fn)(&config), + IntegrationStatus::Available + )); + } + #[test] fn shell_and_filesystem_always_active() { let config = Config::default(); @@ -853,4 +933,54 @@ mod tests { "Expected 5+ AI model integrations, got {ai_count}" ); } + + #[test] + fn regional_provider_aliases_activate_expected_ai_integrations() { + let entries = all_integrations(); + let mut config = Config { + default_provider: Some("minimax-cn".to_string()), + ..Config::default() + }; + + let minimax = entries.iter().find(|e| e.name == "MiniMax").unwrap(); + assert!(matches!( + (minimax.status_fn)(&config), + IntegrationStatus::Active + )); + + config.default_provider = Some("glm-cn".to_string()); + let glm = entries.iter().find(|e| e.name == "GLM").unwrap(); + assert!(matches!( + (glm.status_fn)(&config), + IntegrationStatus::Active + )); + + config.default_provider = Some("moonshot-intl".to_string()); + let moonshot = entries.iter().find(|e| e.name == "Moonshot").unwrap(); + assert!(matches!( + (moonshot.status_fn)(&config), + IntegrationStatus::Active + )); + + config.default_provider = Some("qwen-intl".to_string()); + let qwen = entries.iter().find(|e| e.name == "Qwen").unwrap(); + assert!(matches!( + (qwen.status_fn)(&config), + IntegrationStatus::Active + )); + + config.default_provider = Some("zai-cn".to_string()); + let zai = entries.iter().find(|e| e.name == "Z.AI").unwrap(); + assert!(matches!( + (zai.status_fn)(&config), + IntegrationStatus::Active + )); + + config.default_provider = Some("baidu".to_string()); + let qianfan = entries.iter().find(|e| e.name == "Qianfan").unwrap(); + assert!(matches!( + (qianfan.status_fn)(&config), + IntegrationStatus::Active + )); + } } diff --git a/src/lib.rs b/src/lib.rs index 12c2334..9856880 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,20 +1,254 @@ #![warn(clippy::all, clippy::pedantic)] #![allow( + clippy::assigning_clones, + clippy::bool_to_int_with_if, + clippy::case_sensitive_file_extension_comparisons, + clippy::cast_possible_wrap, + clippy::doc_markdown, + clippy::field_reassign_with_default, + clippy::float_cmp, + clippy::implicit_clone, + clippy::items_after_statements, + clippy::map_unwrap_or, + clippy::manual_let_else, clippy::missing_errors_doc, clippy::missing_panics_doc, - clippy::unnecessary_literal_bound, clippy::module_name_repetitions, - clippy::struct_field_names, clippy::must_use_candidate, clippy::new_without_default, + clippy::needless_pass_by_value, + clippy::needless_raw_string_hashes, + clippy::redundant_closure_for_method_calls, clippy::return_self_not_must_use, + clippy::similar_names, + clippy::single_match_else, + clippy::struct_field_names, + clippy::too_many_lines, + clippy::uninlined_format_args, + clippy::unnecessary_cast, + clippy::unnecessary_lazy_evaluations, + clippy::unnecessary_literal_bound, + clippy::unnecessary_map_or, + clippy::unused_self, + clippy::cast_precision_loss, + clippy::unnecessary_wraps, dead_code )] +use clap::Subcommand; +use serde::{Deserialize, Serialize}; + +pub mod agent; +pub mod approval; +pub mod channels; pub mod config; +pub mod cost; +pub mod cron; +pub mod daemon; +pub mod doctor; +pub mod gateway; +pub mod hardware; +pub mod health; pub mod heartbeat; +pub mod identity; +pub mod integrations; pub mod memory; +pub mod migration; pub mod observability; +pub mod onboard; +pub mod peripherals; pub mod providers; +pub mod rag; pub mod runtime; pub mod security; +pub mod service; +pub mod skills; +pub mod tools; +pub mod tunnel; +pub mod util; + +pub use config::Config; + +/// Service management subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ServiceCommands { + /// Install daemon service unit for auto-start and restart + Install, + /// Start daemon service + Start, + /// Stop daemon service + Stop, + /// Check daemon service status + Status, + /// Uninstall daemon service unit + Uninstall, +} + +/// Channel management subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ChannelCommands { + /// List all configured channels + List, + /// Start all configured channels (handled in main.rs for async) + Start, + /// Run health checks for configured channels (handled in main.rs for async) + Doctor, + /// Add a new channel configuration + Add { + /// Channel type (telegram, discord, slack, whatsapp, matrix, imessage, email) + channel_type: String, + /// Optional configuration as JSON + config: String, + }, + /// Remove a channel configuration + Remove { + /// Channel name to remove + name: String, + }, + /// Bind a Telegram identity (username or numeric user ID) into allowlist + BindTelegram { + /// Telegram identity to allow (username without '@' or numeric user ID) + identity: String, + }, +} + +/// Skills management subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum SkillCommands { + /// List all installed skills + List, + /// Install a new skill from a URL or local path + Install { + /// Source URL or local path + source: String, + }, + /// Remove an installed skill + Remove { + /// Skill name to remove + name: String, + }, +} + +/// Migration subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum MigrateCommands { + /// Import memory from an `OpenClaw` workspace into this `ZeroClaw` workspace + Openclaw { + /// Optional path to `OpenClaw` workspace (defaults to ~/.openclaw/workspace) + #[arg(long)] + source: Option, + + /// Validate and preview migration without writing any data + #[arg(long)] + dry_run: bool, + }, +} + +/// Cron subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum CronCommands { + /// List all scheduled tasks + List, + /// Add a new scheduled task + Add { + /// Cron expression + expression: String, + /// Optional IANA timezone (e.g. America/Los_Angeles) + #[arg(long)] + tz: Option, + /// Command to run + command: String, + }, + /// Add a one-shot scheduled task at an RFC3339 timestamp + AddAt { + /// One-shot timestamp in RFC3339 format + at: String, + /// Command to run + command: String, + }, + /// Add a fixed-interval scheduled task + AddEvery { + /// Interval in milliseconds + every_ms: u64, + /// Command to run + command: String, + }, + /// Add a one-shot delayed task (e.g. "30m", "2h", "1d") + Once { + /// Delay duration + delay: String, + /// Command to run + command: String, + }, + /// Remove a scheduled task + Remove { + /// Task ID + id: String, + }, + /// Pause a scheduled task + Pause { + /// Task ID + id: String, + }, + /// Resume a paused task + Resume { + /// Task ID + id: String, + }, +} + +/// Integration subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum IntegrationCommands { + /// Show details about a specific integration + Info { + /// Integration name + name: String, + }, +} + +/// Hardware discovery subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum HardwareCommands { + /// Enumerate USB devices (VID/PID) and show known boards + Discover, + /// Introspect a device by path (e.g. /dev/ttyACM0) + Introspect { + /// Serial or device path + path: String, + }, + /// Get chip info via USB (probe-rs over ST-Link). No firmware needed on target. + Info { + /// Chip name (e.g. STM32F401RETx). Default: STM32F401RETx for Nucleo-F401RE + #[arg(long, default_value = "STM32F401RETx")] + chip: String, + }, +} + +/// Peripheral (hardware) management subcommands +#[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum PeripheralCommands { + /// List configured peripherals + List, + /// Add a peripheral (board path, e.g. nucleo-f401re /dev/ttyACM0) + Add { + /// Board type (nucleo-f401re, rpi-gpio, esp32) + board: String, + /// Path for serial transport (/dev/ttyACM0) or "native" for local GPIO + path: String, + }, + /// Flash ZeroClaw firmware to Arduino (creates .ino, installs arduino-cli if needed, uploads) + Flash { + /// Serial port (e.g. /dev/cu.usbmodem12345). If omitted, uses first arduino-uno from config. + #[arg(short, long)] + port: Option, + }, + /// Setup Arduino Uno Q Bridge app (deploy GPIO bridge for agent control) + SetupUnoQ { + /// Uno Q IP (e.g. 192.168.0.48). If omitted, assumes running ON the Uno Q. + #[arg(long)] + host: Option, + }, + /// Flash ZeroClaw firmware to Nucleo-F401RE (builds + probe-rs run) + FlashNucleo, +} diff --git a/src/main.rs b/src/main.rs index 4d07ad2..f9488c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,78 @@ #![warn(clippy::all, clippy::pedantic)] #![allow( + clippy::assigning_clones, + clippy::bool_to_int_with_if, + clippy::case_sensitive_file_extension_comparisons, + clippy::cast_possible_wrap, + clippy::doc_markdown, + clippy::field_reassign_with_default, + clippy::float_cmp, + clippy::implicit_clone, + clippy::items_after_statements, + clippy::map_unwrap_or, + clippy::manual_let_else, clippy::missing_errors_doc, clippy::missing_panics_doc, - clippy::unnecessary_literal_bound, clippy::module_name_repetitions, + clippy::needless_pass_by_value, + clippy::needless_raw_string_hashes, + clippy::redundant_closure_for_method_calls, + clippy::similar_names, + clippy::single_match_else, clippy::struct_field_names, + clippy::too_many_lines, + clippy::uninlined_format_args, + clippy::unused_self, + clippy::cast_precision_loss, + clippy::unnecessary_cast, + clippy::unnecessary_lazy_evaluations, + clippy::unnecessary_literal_bound, + clippy::unnecessary_map_or, + clippy::unnecessary_wraps, dead_code )] use anyhow::{bail, Result}; use clap::{Parser, Subcommand}; -use tracing::{info, Level}; -use tracing_subscriber::FmtSubscriber; +use tracing::info; +use tracing_subscriber::{fmt, EnvFilter}; mod agent; +mod approval; mod channels; +mod rag { + pub use zeroclaw::rag::*; +} mod config; mod cron; mod daemon; mod doctor; mod gateway; +mod hardware; mod health; mod heartbeat; +mod identity; mod integrations; mod memory; mod migration; mod observability; mod onboard; +mod peripherals; mod providers; mod runtime; mod security; mod service; +mod skillforge; mod skills; mod tools; mod tunnel; +mod util; use config::Config; +// Re-export so binary's hardware/peripherals modules can use crate::HardwareCommands etc. +pub use zeroclaw::{HardwareCommands, PeripheralCommands}; + /// `ZeroClaw` - Zero overhead. Zero compromise. 100% Rust. #[derive(Parser, Debug)] #[command(name = "zeroclaw")] @@ -82,7 +118,7 @@ enum Commands { #[arg(long)] provider: Option, - /// Memory backend (sqlite, markdown, none) - used in quick mode, default: sqlite + /// Memory backend (sqlite, lucid, markdown, none) - used in quick mode, default: sqlite #[arg(long)] memory: Option, }, @@ -104,28 +140,32 @@ enum Commands { /// Temperature (0.0 - 2.0) #[arg(short, long, default_value = "0.7")] temperature: f64, + + /// Attach a peripheral (board:path, e.g. nucleo-f401re:/dev/ttyACM0) + #[arg(long)] + peripheral: Vec, }, /// Start the gateway server (webhooks, websockets) Gateway { - /// Port to listen on (use 0 for random available port) - #[arg(short, long, default_value = "8080")] - port: u16, + /// Port to listen on (use 0 for random available port); defaults to config gateway.port + #[arg(short, long)] + port: Option, - /// Host to bind to - #[arg(long, default_value = "127.0.0.1")] - host: String, + /// Host to bind to; defaults to config gateway.host + #[arg(long)] + host: Option, }, /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) Daemon { - /// Port to listen on (use 0 for random available port) - #[arg(short, long, default_value = "8080")] - port: u16, + /// Port to listen on (use 0 for random available port); defaults to config gateway.port + #[arg(short, long)] + port: Option, - /// Host to bind to - #[arg(long, default_value = "127.0.0.1")] - host: String, + /// Host to bind to; defaults to config gateway.host + #[arg(long)] + host: Option, }, /// Manage OS service lifecycle (launchd/systemd user service) @@ -146,6 +186,15 @@ enum Commands { cron_command: CronCommands, }, + /// Manage provider model catalogs + Models { + #[command(subcommand)] + model_command: ModelCommands, + }, + + /// List supported AI providers + Providers, + /// Manage channels (telegram, discord, slack) Channel { #[command(subcommand)] @@ -169,6 +218,18 @@ enum Commands { #[command(subcommand)] migrate_command: MigrateCommands, }, + + /// Discover and introspect USB hardware + Hardware { + #[command(subcommand)] + hardware_command: zeroclaw::HardwareCommands, + }, + + /// Manage hardware peripherals (STM32, RPi GPIO, etc.) + Peripheral { + #[command(subcommand)] + peripheral_command: zeroclaw::PeripheralCommands, + }, } #[derive(Subcommand, Debug)] @@ -193,6 +254,30 @@ enum CronCommands { Add { /// Cron expression expression: String, + /// Optional IANA timezone (e.g. America/Los_Angeles) + #[arg(long)] + tz: Option, + /// Command to run + command: String, + }, + /// Add a one-shot scheduled task at an RFC3339 timestamp + AddAt { + /// One-shot timestamp in RFC3339 format + at: String, + /// Command to run + command: String, + }, + /// Add a fixed-interval scheduled task + AddEvery { + /// Interval in milliseconds + every_ms: u64, + /// Command to run + command: String, + }, + /// Add a one-shot delayed task (e.g. "30m", "2h", "1d") + Once { + /// Delay duration + delay: String, /// Command to run command: String, }, @@ -201,6 +286,30 @@ enum CronCommands { /// Task ID id: String, }, + /// Pause a scheduled task + Pause { + /// Task ID + id: String, + }, + /// Resume a paused task + Resume { + /// Task ID + id: String, + }, +} + +#[derive(Subcommand, Debug)] +enum ModelCommands { + /// Refresh and cache provider models + Refresh { + /// Provider name (defaults to configured default provider) + #[arg(long)] + provider: Option, + + /// Force live refresh and ignore fresh cache + #[arg(long)] + force: bool, + }, } #[derive(Subcommand, Debug)] @@ -223,6 +332,11 @@ enum ChannelCommands { /// Channel name name: String, }, + /// Bind a Telegram identity (username or numeric user ID) into allowlist + BindTelegram { + /// Telegram identity to allow (username without '@' or numeric user ID) + identity: String, + }, } #[derive(Subcommand, Debug)] @@ -253,16 +367,28 @@ enum IntegrationCommands { #[tokio::main] #[allow(clippy::too_many_lines)] async fn main() -> Result<()> { + // Install default crypto provider for Rustls TLS. + // This prevents the error: "could not automatically determine the process-level CryptoProvider" + // when both aws-lc-rs and ring features are available (or neither is explicitly selected). + if let Err(e) = rustls::crypto::ring::default_provider().install_default() { + eprintln!("Warning: Failed to install default crypto provider: {e:?}"); + } + let cli = Cli::parse(); - // Initialize logging - let subscriber = FmtSubscriber::builder() - .with_max_level(Level::INFO) + // Initialize logging - respects RUST_LOG env var, defaults to INFO + let subscriber = fmt::Subscriber::builder() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) .finish(); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); - // Onboard runs quick setup by default, or the interactive wizard with --interactive + // Onboard runs quick setup by default, or the interactive wizard with --interactive. + // The onboard wizard uses reqwest::blocking internally, which creates its own + // Tokio runtime. To avoid "Cannot drop a runtime in a context where blocking is + // not allowed", we run the wizard on a blocking thread via spawn_blocking. if let Commands::Onboard { interactive, channels_only, @@ -271,20 +397,29 @@ async fn main() -> Result<()> { memory, } = &cli.command { - if *interactive && *channels_only { + let interactive = *interactive; + let channels_only = *channels_only; + let api_key = api_key.clone(); + let provider = provider.clone(); + let memory = memory.clone(); + + if interactive && channels_only { bail!("Use either --interactive or --channels-only, not both"); } - if *channels_only && (api_key.is_some() || provider.is_some() || memory.is_some()) { + if channels_only && (api_key.is_some() || provider.is_some() || memory.is_some()) { bail!("--channels-only does not accept --api-key, --provider, or --memory"); } - let config = if *channels_only { - onboard::run_channels_repair_wizard()? - } else if *interactive { - onboard::run_wizard()? - } else { - onboard::run_quick_setup(api_key.as_deref(), provider.as_deref(), memory.as_deref())? - }; + let config = tokio::task::spawn_blocking(move || { + if channels_only { + onboard::run_channels_repair_wizard() + } else if interactive { + onboard::run_wizard() + } else { + onboard::run_quick_setup(api_key.as_deref(), provider.as_deref(), memory.as_deref()) + } + }) + .await??; // Auto-start channels if user said yes during wizard if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { channels::start_channels(config).await?; @@ -293,7 +428,8 @@ async fn main() -> Result<()> { } // All other commands need config loaded first - let config = Config::load_or_init()?; + let mut config = Config::load_or_init()?; + config.apply_env_overrides(); match cli.command { Commands::Onboard { .. } => unreachable!(), @@ -303,9 +439,14 @@ async fn main() -> Result<()> { provider, model, temperature, - } => agent::run(config, message, provider, model, temperature).await, + peripheral, + } => agent::run(config, message, provider, model, temperature, peripheral) + .await + .map(|_| ()), Commands::Gateway { port, host } => { + let port = port.unwrap_or(config.gateway.port); + let host = host.unwrap_or_else(|| config.gateway.host.clone()); if port == 0 { info!("🚀 Starting ZeroClaw Gateway on {host} (random port)"); } else { @@ -315,6 +456,8 @@ async fn main() -> Result<()> { } Commands::Daemon { port, host } => { + let port = port.unwrap_or(config.gateway.port); + let host = host.unwrap_or_else(|| config.gateway.host.clone()); if port == 0 { info!("🧠 Starting ZeroClaw Daemon on {host} (random port)"); } else { @@ -388,12 +531,62 @@ async fn main() -> Result<()> { } ); } + println!(); + println!("Peripherals:"); + println!( + " Enabled: {}", + if config.peripherals.enabled { + "yes" + } else { + "no" + } + ); + println!(" Boards: {}", config.peripherals.boards.len()); Ok(()) } Commands::Cron { cron_command } => cron::handle_command(cron_command, &config), + Commands::Models { model_command } => match model_command { + ModelCommands::Refresh { provider, force } => { + onboard::run_models_refresh(&config, provider.as_deref(), force) + } + }, + + Commands::Providers => { + let providers = providers::list_providers(); + let current = config + .default_provider + .as_deref() + .unwrap_or("openrouter") + .trim() + .to_ascii_lowercase(); + println!("Supported providers ({} total):\n", providers.len()); + println!(" ID (use in config) DESCRIPTION"); + println!(" ─────────────────── ───────────"); + for p in &providers { + let is_active = p.name.eq_ignore_ascii_case(¤t) + || p.aliases + .iter() + .any(|alias| alias.eq_ignore_ascii_case(¤t)); + let marker = if is_active { " (active)" } else { "" }; + let local_tag = if p.local { " [local]" } else { "" }; + let aliases = if p.aliases.is_empty() { + String::new() + } else { + format!(" (aliases: {})", p.aliases.join(", ")) + }; + println!( + " {:<19} {}{}{}{}", + p.name, p.display_name, local_tag, marker, aliases + ); + } + println!("\n custom: Any OpenAI-compatible endpoint"); + println!(" anthropic-custom: Any Anthropic-compatible endpoint"); + Ok(()) + } + Commands::Service { service_command } => service::handle_command(&service_command, &config), Commands::Doctor => doctor::run(&config), @@ -415,6 +608,14 @@ async fn main() -> Result<()> { Commands::Migrate { migrate_command } => { migration::handle_command(migrate_command, &config).await } + + Commands::Hardware { hardware_command } => { + hardware::handle_command(hardware_command.clone(), &config) + } + + Commands::Peripheral { peripheral_command } => { + peripherals::handle_command(peripheral_command.clone(), &config) + } } } diff --git a/src/memory/backend.rs b/src/memory/backend.rs new file mode 100644 index 0000000..8ba7ec3 --- /dev/null +++ b/src/memory/backend.rs @@ -0,0 +1,146 @@ +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum MemoryBackendKind { + Sqlite, + Lucid, + Markdown, + None, + Unknown, +} + +#[allow(clippy::struct_excessive_bools)] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct MemoryBackendProfile { + pub key: &'static str, + pub label: &'static str, + pub auto_save_default: bool, + pub uses_sqlite_hygiene: bool, + pub sqlite_based: bool, + pub optional_dependency: bool, +} + +const SQLITE_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "sqlite", + label: "SQLite with Vector Search (recommended) — fast, hybrid search, embeddings", + auto_save_default: true, + uses_sqlite_hygiene: true, + sqlite_based: true, + optional_dependency: false, +}; + +const LUCID_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "lucid", + label: "Lucid Memory bridge — sync with local lucid-memory CLI, keep SQLite fallback", + auto_save_default: true, + uses_sqlite_hygiene: true, + sqlite_based: true, + optional_dependency: true, +}; + +const MARKDOWN_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "markdown", + label: "Markdown Files — simple, human-readable, no dependencies", + auto_save_default: true, + uses_sqlite_hygiene: false, + sqlite_based: false, + optional_dependency: false, +}; + +const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "none", + label: "None — disable persistent memory", + auto_save_default: false, + uses_sqlite_hygiene: false, + sqlite_based: false, + optional_dependency: false, +}; + +const CUSTOM_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "custom", + label: "Custom backend — extension point", + auto_save_default: true, + uses_sqlite_hygiene: false, + sqlite_based: false, + optional_dependency: false, +}; + +const SELECTABLE_MEMORY_BACKENDS: [MemoryBackendProfile; 4] = [ + SQLITE_PROFILE, + LUCID_PROFILE, + MARKDOWN_PROFILE, + NONE_PROFILE, +]; + +pub fn selectable_memory_backends() -> &'static [MemoryBackendProfile] { + &SELECTABLE_MEMORY_BACKENDS +} + +pub fn default_memory_backend_key() -> &'static str { + SQLITE_PROFILE.key +} + +pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind { + match backend { + "sqlite" => MemoryBackendKind::Sqlite, + "lucid" => MemoryBackendKind::Lucid, + "markdown" => MemoryBackendKind::Markdown, + "none" => MemoryBackendKind::None, + _ => MemoryBackendKind::Unknown, + } +} + +pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile { + match classify_memory_backend(backend) { + MemoryBackendKind::Sqlite => SQLITE_PROFILE, + MemoryBackendKind::Lucid => LUCID_PROFILE, + MemoryBackendKind::Markdown => MARKDOWN_PROFILE, + MemoryBackendKind::None => NONE_PROFILE, + MemoryBackendKind::Unknown => CUSTOM_PROFILE, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_known_backends() { + assert_eq!(classify_memory_backend("sqlite"), MemoryBackendKind::Sqlite); + assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid); + assert_eq!( + classify_memory_backend("markdown"), + MemoryBackendKind::Markdown + ); + assert_eq!(classify_memory_backend("none"), MemoryBackendKind::None); + } + + #[test] + fn classify_unknown_backend() { + assert_eq!(classify_memory_backend("redis"), MemoryBackendKind::Unknown); + } + + #[test] + fn selectable_backends_are_ordered_for_onboarding() { + let backends = selectable_memory_backends(); + assert_eq!(backends.len(), 4); + assert_eq!(backends[0].key, "sqlite"); + assert_eq!(backends[1].key, "lucid"); + assert_eq!(backends[2].key, "markdown"); + assert_eq!(backends[3].key, "none"); + } + + #[test] + fn lucid_profile_is_sqlite_based_optional_backend() { + let profile = memory_backend_profile("lucid"); + assert!(profile.sqlite_based); + assert!(profile.optional_dependency); + assert!(profile.uses_sqlite_hygiene); + } + + #[test] + fn unknown_profile_preserves_extensibility_defaults() { + let profile = memory_backend_profile("custom-memory"); + assert_eq!(profile.key, "custom"); + assert!(profile.auto_save_default); + assert!(!profile.uses_sqlite_hygiene); + } +} diff --git a/src/memory/embeddings.rs b/src/memory/embeddings.rs index 270ebfe..fdb0cb1 100644 --- a/src/memory/embeddings.rs +++ b/src/memory/embeddings.rs @@ -60,6 +60,35 @@ impl OpenAiEmbedding { dims, } } + + fn has_explicit_api_path(&self) -> bool { + let Ok(url) = reqwest::Url::parse(&self.base_url) else { + return false; + }; + + let path = url.path().trim_end_matches('/'); + !path.is_empty() && path != "/" + } + + fn has_embeddings_endpoint(&self) -> bool { + let Ok(url) = reqwest::Url::parse(&self.base_url) else { + return false; + }; + + url.path().trim_end_matches('/').ends_with("/embeddings") + } + + fn embeddings_url(&self) -> String { + if self.has_embeddings_endpoint() { + return self.base_url.clone(); + } + + if self.has_explicit_api_path() { + format!("{}/embeddings", self.base_url) + } else { + format!("{}/v1/embeddings", self.base_url) + } + } } #[async_trait] @@ -84,7 +113,7 @@ impl EmbeddingProvider for OpenAiEmbedding { let resp = self .client - .post(format!("{}/v1/embeddings", self.base_url)) + .post(self.embeddings_url()) .header("Authorization", format!("Bearer {}", self.api_key)) .header("Content-Type", "application/json") .json(&body) @@ -249,4 +278,44 @@ mod tests { let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384); assert_eq!(p.dimensions(), 384); } + + #[test] + fn embeddings_url_standard_openai() { + let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536); + assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings"); + } + + #[test] + fn embeddings_url_base_with_v1_no_duplicate() { + let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536); + assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings"); + } + + #[test] + fn embeddings_url_non_v1_api_path_uses_raw_suffix() { + let p = OpenAiEmbedding::new( + "https://api.example.com/api/coding/v3", + "key", + "model", + 1536, + ); + assert_eq!( + p.embeddings_url(), + "https://api.example.com/api/coding/v3/embeddings" + ); + } + + #[test] + fn embeddings_url_custom_full_endpoint() { + let p = OpenAiEmbedding::new( + "https://my-api.example.com/api/v2/embeddings", + "key", + "model", + 1536, + ); + assert_eq!( + p.embeddings_url(), + "https://my-api.example.com/api/v2/embeddings" + ); + } } diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs index 17c95fa..01054ce 100644 --- a/src/memory/hygiene.rs +++ b/src/memory/hygiene.rs @@ -306,6 +306,8 @@ fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result< } let conn = Connection::open(db_path)?; + // Use WAL so hygiene pruning doesn't block agent reads + conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?; let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339(); let affected = conn.execute( @@ -326,7 +328,7 @@ fn date_prefix(filename: &str) -> Option { if filename.len() < 10 { return None; } - NaiveDate::parse_from_str(&filename[..10], "%Y-%m-%d").ok() + NaiveDate::parse_from_str(&filename[..filename.floor_char_boundary(10)], "%Y-%m-%d").ok() } fn is_older_than(path: &Path, cutoff: SystemTime) -> bool { @@ -500,10 +502,10 @@ mod tests { let workspace = tmp.path(); let mem = SqliteMemory::new(workspace).unwrap(); - mem.store("conv_old", "outdated", MemoryCategory::Conversation) + mem.store("conv_old", "outdated", MemoryCategory::Conversation, None) .await .unwrap(); - mem.store("core_keep", "durable", MemoryCategory::Core) + mem.store("core_keep", "durable", MemoryCategory::Core, None) .await .unwrap(); drop(mem); diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs new file mode 100644 index 0000000..62af08f --- /dev/null +++ b/src/memory/lucid.rs @@ -0,0 +1,675 @@ +use super::sqlite::SqliteMemory; +use super::traits::{Memory, MemoryCategory, MemoryEntry}; +use async_trait::async_trait; +use chrono::Local; +use parking_lot::Mutex; +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use std::time::{Duration, Instant}; +use tokio::process::Command; +use tokio::time::timeout; + +pub struct LucidMemory { + local: SqliteMemory, + lucid_cmd: String, + token_budget: usize, + workspace_dir: PathBuf, + recall_timeout: Duration, + store_timeout: Duration, + local_hit_threshold: usize, + failure_cooldown: Duration, + last_failure_at: Mutex>, +} + +impl LucidMemory { + const DEFAULT_LUCID_CMD: &'static str = "lucid"; + const DEFAULT_TOKEN_BUDGET: usize = 200; + // Lucid CLI cold start can exceed 120ms on slower machines, which causes + // avoidable fallback to local-only memory and premature cooldown. + const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500; + const DEFAULT_STORE_TIMEOUT_MS: u64 = 800; + const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3; + const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000; + + pub fn new(workspace_dir: &Path, local: SqliteMemory) -> Self { + let lucid_cmd = std::env::var("ZEROCLAW_LUCID_CMD") + .unwrap_or_else(|_| Self::DEFAULT_LUCID_CMD.to_string()); + + let token_budget = std::env::var("ZEROCLAW_LUCID_BUDGET") + .ok() + .and_then(|v| v.parse::().ok()) + .filter(|v| *v > 0) + .unwrap_or(Self::DEFAULT_TOKEN_BUDGET); + + let recall_timeout = Self::read_env_duration_ms( + "ZEROCLAW_LUCID_RECALL_TIMEOUT_MS", + Self::DEFAULT_RECALL_TIMEOUT_MS, + 20, + ); + let store_timeout = Self::read_env_duration_ms( + "ZEROCLAW_LUCID_STORE_TIMEOUT_MS", + Self::DEFAULT_STORE_TIMEOUT_MS, + 50, + ); + let local_hit_threshold = Self::read_env_usize( + "ZEROCLAW_LUCID_LOCAL_HIT_THRESHOLD", + Self::DEFAULT_LOCAL_HIT_THRESHOLD, + 1, + ); + let failure_cooldown = Self::read_env_duration_ms( + "ZEROCLAW_LUCID_FAILURE_COOLDOWN_MS", + Self::DEFAULT_FAILURE_COOLDOWN_MS, + 100, + ); + + Self { + local, + lucid_cmd, + token_budget, + workspace_dir: workspace_dir.to_path_buf(), + recall_timeout, + store_timeout, + local_hit_threshold, + failure_cooldown, + last_failure_at: Mutex::new(None), + } + } + + #[cfg(test)] + #[allow(clippy::too_many_arguments)] + fn with_options( + workspace_dir: &Path, + local: SqliteMemory, + lucid_cmd: String, + token_budget: usize, + local_hit_threshold: usize, + recall_timeout: Duration, + store_timeout: Duration, + failure_cooldown: Duration, + ) -> Self { + Self { + local, + lucid_cmd, + token_budget, + workspace_dir: workspace_dir.to_path_buf(), + recall_timeout, + store_timeout, + local_hit_threshold: local_hit_threshold.max(1), + failure_cooldown, + last_failure_at: Mutex::new(None), + } + } + + fn read_env_usize(name: &str, default: usize, min: usize) -> usize { + std::env::var(name) + .ok() + .and_then(|v| v.parse::().ok()) + .map_or(default, |v| v.max(min)) + } + + fn read_env_duration_ms(name: &str, default_ms: u64, min_ms: u64) -> Duration { + let millis = std::env::var(name) + .ok() + .and_then(|v| v.parse::().ok()) + .map_or(default_ms, |v| v.max(min_ms)); + Duration::from_millis(millis) + } + + fn in_failure_cooldown(&self) -> bool { + let guard = self.last_failure_at.lock(); + guard + .as_ref() + .is_some_and(|last| last.elapsed() < self.failure_cooldown) + } + + fn mark_failure_now(&self) { + let mut guard = self.last_failure_at.lock(); + *guard = Some(Instant::now()); + } + + fn clear_failure(&self) { + let mut guard = self.last_failure_at.lock(); + *guard = None; + } + + fn to_lucid_type(category: &MemoryCategory) -> &'static str { + match category { + MemoryCategory::Core => "decision", + MemoryCategory::Daily => "context", + MemoryCategory::Conversation => "conversation", + MemoryCategory::Custom(_) => "learning", + } + } + + fn to_memory_category(label: &str) -> MemoryCategory { + let normalized = label.to_lowercase(); + if normalized.contains("visual") { + return MemoryCategory::Custom("visual".to_string()); + } + + match normalized.as_str() { + "decision" | "learning" | "solution" => MemoryCategory::Core, + "context" | "conversation" => MemoryCategory::Conversation, + "bug" => MemoryCategory::Daily, + other => MemoryCategory::Custom(other.to_string()), + } + } + + fn merge_results( + primary_results: Vec, + secondary_results: Vec, + limit: usize, + ) -> Vec { + if limit == 0 { + return Vec::new(); + } + + let mut merged = Vec::new(); + let mut seen = HashSet::new(); + + for entry in primary_results.into_iter().chain(secondary_results) { + let signature = format!( + "{}\u{0}{}", + entry.key.to_lowercase(), + entry.content.to_lowercase() + ); + + if seen.insert(signature) { + merged.push(entry); + if merged.len() >= limit { + break; + } + } + } + + merged + } + + fn parse_lucid_context(raw: &str) -> Vec { + let mut in_context_block = false; + let mut entries = Vec::new(); + let now = Local::now().to_rfc3339(); + + for line in raw.lines().map(str::trim) { + if line == "" { + in_context_block = true; + continue; + } + + if line == "" { + break; + } + + if !in_context_block || line.is_empty() { + continue; + } + + let Some(rest) = line.strip_prefix("- [") else { + continue; + }; + + let Some((label, content_part)) = rest.split_once(']') else { + continue; + }; + + let content = content_part.trim(); + if content.is_empty() { + continue; + } + + let rank = entries.len(); + entries.push(MemoryEntry { + id: format!("lucid:{rank}"), + key: format!("lucid_{rank}"), + content: content.to_string(), + category: Self::to_memory_category(label.trim()), + timestamp: now.clone(), + session_id: None, + score: Some((1.0 - rank as f64 * 0.05).max(0.1)), + }); + } + + entries + } + + async fn run_lucid_command_raw( + lucid_cmd: &str, + args: &[String], + timeout_window: Duration, + ) -> anyhow::Result { + let mut cmd = Command::new(lucid_cmd); + cmd.args(args); + + let output = timeout(timeout_window, cmd.output()).await.map_err(|_| { + anyhow::anyhow!( + "lucid command timed out after {}ms", + timeout_window.as_millis() + ) + })??; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("lucid command failed: {stderr}"); + } + + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } + + async fn run_lucid_command( + &self, + args: &[String], + timeout_window: Duration, + ) -> anyhow::Result { + Self::run_lucid_command_raw(&self.lucid_cmd, args, timeout_window).await + } + + fn build_store_args(&self, key: &str, content: &str, category: &MemoryCategory) -> Vec { + let payload = format!("{key}: {content}"); + vec![ + "store".to_string(), + payload, + format!("--type={}", Self::to_lucid_type(category)), + format!("--project={}", self.workspace_dir.display()), + ] + } + + fn build_recall_args(&self, query: &str) -> Vec { + vec![ + "context".to_string(), + query.to_string(), + format!("--budget={}", self.token_budget), + format!("--project={}", self.workspace_dir.display()), + ] + } + + async fn sync_to_lucid_async(&self, key: &str, content: &str, category: &MemoryCategory) { + let args = self.build_store_args(key, content, category); + if let Err(error) = self.run_lucid_command(&args, self.store_timeout).await { + tracing::debug!( + command = %self.lucid_cmd, + error = %error, + "Lucid store sync failed; sqlite remains authoritative" + ); + } + } + + async fn recall_from_lucid(&self, query: &str) -> anyhow::Result> { + let args = self.build_recall_args(query); + let output = self.run_lucid_command(&args, self.recall_timeout).await?; + Ok(Self::parse_lucid_context(&output)) + } +} + +#[async_trait] +impl Memory for LucidMemory { + fn name(&self) -> &str { + "lucid" + } + + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()> { + self.local + .store(key, content, category.clone(), session_id) + .await?; + self.sync_to_lucid_async(key, content, &category).await; + Ok(()) + } + + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { + let local_results = self.local.recall(query, limit, session_id).await?; + if limit == 0 + || local_results.len() >= limit + || local_results.len() >= self.local_hit_threshold + { + return Ok(local_results); + } + + if self.in_failure_cooldown() { + return Ok(local_results); + } + + match self.recall_from_lucid(query).await { + Ok(lucid_results) if !lucid_results.is_empty() => { + self.clear_failure(); + Ok(Self::merge_results(local_results, lucid_results, limit)) + } + Ok(_) => { + self.clear_failure(); + Ok(local_results) + } + Err(error) => { + self.mark_failure_now(); + tracing::debug!( + command = %self.lucid_cmd, + error = %error, + "Lucid context unavailable; using local sqlite results" + ); + Ok(local_results) + } + } + } + + async fn get(&self, key: &str) -> anyhow::Result> { + self.local.get(key).await + } + + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { + self.local.list(category, session_id).await + } + + async fn forget(&self, key: &str) -> anyhow::Result { + self.local.forget(key).await + } + + async fn count(&self) -> anyhow::Result { + self.local.count().await + } + + async fn health_check(&self) -> bool { + self.local.health_check().await + } +} + +#[cfg(all(test, unix))] +mod tests { + use super::*; + use std::fs; + use std::os::unix::fs::PermissionsExt; + use tempfile::TempDir; + + fn write_fake_lucid_script(dir: &Path) -> String { + let script_path = dir.join("fake-lucid.sh"); + let script = r#"#!/usr/bin/env bash +set -euo pipefail + +if [[ "${1:-}" == "store" ]]; then + echo '{"success":true,"id":"mem_1"}' + exit 0 +fi + +if [[ "${1:-}" == "context" ]]; then + cat <<'EOF' + +Auth context snapshot +- [decision] Use token refresh middleware +- [context] Working in src/auth.rs + +EOF + exit 0 +fi + +echo "unsupported command" >&2 +exit 1 +"#; + + fs::write(&script_path, script).unwrap(); + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + script_path.display().to_string() + } + + fn write_delayed_lucid_script(dir: &Path) -> String { + let script_path = dir.join("delayed-lucid.sh"); + let script = r#"#!/usr/bin/env bash +set -euo pipefail + +if [[ "${1:-}" == "store" ]]; then + echo '{"success":true,"id":"mem_1"}' + exit 0 +fi + +if [[ "${1:-}" == "context" ]]; then + # Simulate a cold start that is slower than 120ms but below the 500ms timeout. + sleep 0.2 + cat <<'EOF' + +- [decision] Delayed token refresh guidance + +EOF + exit 0 +fi + +echo "unsupported command" >&2 +exit 1 +"#; + + fs::write(&script_path, script).unwrap(); + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + script_path.display().to_string() + } + + fn write_probe_lucid_script(dir: &Path, marker_path: &Path) -> String { + let script_path = dir.join("probe-lucid.sh"); + let marker = marker_path.display().to_string(); + let script = format!( + r#"#!/usr/bin/env bash +set -euo pipefail + +if [[ "${{1:-}}" == "store" ]]; then + echo '{{"success":true,"id":"mem_store"}}' + exit 0 +fi + +if [[ "${{1:-}}" == "context" ]]; then + printf 'context\n' >> "{marker}" + cat <<'EOF' + +- [decision] should not be used when local hits are enough + +EOF + exit 0 +fi + +echo "unsupported command" >&2 +exit 1 +"# + ); + + fs::write(&script_path, script).unwrap(); + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + script_path.display().to_string() + } + + fn test_memory(workspace: &Path, cmd: String) -> LucidMemory { + let sqlite = SqliteMemory::new(workspace).unwrap(); + LucidMemory::with_options( + workspace, + sqlite, + cmd, + 200, + 3, + Duration::from_millis(500), + Duration::from_millis(400), + Duration::from_secs(2), + ) + } + + #[tokio::test] + async fn lucid_name() { + let tmp = TempDir::new().unwrap(); + let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string()); + assert_eq!(memory.name(), "lucid"); + } + + #[tokio::test] + async fn store_succeeds_when_lucid_missing() { + let tmp = TempDir::new().unwrap(); + let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string()); + + memory + .store("lang", "User prefers Rust", MemoryCategory::Core, None) + .await + .unwrap(); + + let entry = memory.get("lang").await.unwrap(); + assert!(entry.is_some()); + assert_eq!(entry.unwrap().content, "User prefers Rust"); + } + + #[tokio::test] + async fn recall_merges_lucid_and_local_results() { + let tmp = TempDir::new().unwrap(); + let fake_cmd = write_fake_lucid_script(tmp.path()); + let memory = test_memory(tmp.path(), fake_cmd); + + memory + .store( + "local_note", + "Local sqlite auth fallback note", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + + let entries = memory.recall("auth", 5, None).await.unwrap(); + + assert!(entries + .iter() + .any(|e| e.content.contains("Local sqlite auth fallback note"))); + assert!(entries.iter().any(|e| e.content.contains("token refresh"))); + } + + #[tokio::test] + async fn recall_handles_lucid_cold_start_delay_within_timeout() { + let tmp = TempDir::new().unwrap(); + let delayed_cmd = write_delayed_lucid_script(tmp.path()); + let memory = test_memory(tmp.path(), delayed_cmd); + + memory + .store( + "local_note", + "Local sqlite auth fallback note", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + + let entries = memory.recall("auth", 5, None).await.unwrap(); + + assert!(entries + .iter() + .any(|e| e.content.contains("Local sqlite auth fallback note"))); + assert!(entries + .iter() + .any(|e| e.content.contains("Delayed token refresh guidance"))); + } + + #[tokio::test] + async fn recall_skips_lucid_when_local_hits_are_enough() { + let tmp = TempDir::new().unwrap(); + let marker = tmp.path().join("context_calls.log"); + let probe_cmd = write_probe_lucid_script(tmp.path(), &marker); + + let sqlite = SqliteMemory::new(tmp.path()).unwrap(); + let memory = LucidMemory::with_options( + tmp.path(), + sqlite, + probe_cmd, + 200, + 1, + Duration::from_millis(500), + Duration::from_millis(400), + Duration::from_secs(2), + ); + + memory + .store( + "pref", + "Rust should stay local-first", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + + let entries = memory.recall("rust", 5, None).await.unwrap(); + assert!(entries + .iter() + .any(|e| e.content.contains("Rust should stay local-first"))); + + let context_calls = fs::read_to_string(&marker).unwrap_or_default(); + assert!( + context_calls.trim().is_empty(), + "Expected local-hit short-circuit; got calls: {context_calls}" + ); + } + + fn write_failing_lucid_script(dir: &Path, marker_path: &Path) -> String { + let script_path = dir.join("failing-lucid.sh"); + let marker = marker_path.display().to_string(); + let script = format!( + r#"#!/usr/bin/env bash +set -euo pipefail + +if [[ "${{1:-}}" == "store" ]]; then + echo '{{"success":true,"id":"mem_store"}}' + exit 0 +fi + +if [[ "${{1:-}}" == "context" ]]; then + printf 'context\n' >> "{marker}" + echo "simulated lucid failure" >&2 + exit 1 +fi + +echo "unsupported command" >&2 +exit 1 +"# + ); + + fs::write(&script_path, script).unwrap(); + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + script_path.display().to_string() + } + + #[tokio::test] + async fn failure_cooldown_avoids_repeated_lucid_calls() { + let tmp = TempDir::new().unwrap(); + let marker = tmp.path().join("failing_context_calls.log"); + let failing_cmd = write_failing_lucid_script(tmp.path(), &marker); + + let sqlite = SqliteMemory::new(tmp.path()).unwrap(); + let memory = LucidMemory::with_options( + tmp.path(), + sqlite, + failing_cmd, + 200, + 99, + Duration::from_millis(500), + Duration::from_millis(400), + Duration::from_secs(5), + ); + + let first = memory.recall("auth", 5, None).await.unwrap(); + let second = memory.recall("auth", 5, None).await.unwrap(); + + assert!(first.is_empty()); + assert!(second.is_empty()); + + let calls = fs::read_to_string(&marker).unwrap_or_default(); + assert_eq!(calls.lines().count(), 1); + } +} diff --git a/src/memory/markdown.rs b/src/memory/markdown.rs index 8dcd667..9038683 100644 --- a/src/memory/markdown.rs +++ b/src/memory/markdown.rs @@ -143,6 +143,7 @@ impl Memory for MarkdownMemory { key: &str, content: &str, category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { let entry = format!("- **{key}**: {content}"); let path = match category { @@ -152,7 +153,12 @@ impl Memory for MarkdownMemory { self.append_to_file(&path, &entry).await } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; let query_lower = query.to_lowercase(); let keywords: Vec<&str> = query_lower.split_whitespace().collect(); @@ -192,7 +198,11 @@ impl Memory for MarkdownMemory { .find(|e| e.key == key || e.content.contains(key))) } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; match category { Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()), @@ -243,7 +253,7 @@ mod tests { #[tokio::test] async fn markdown_store_core() { let (_tmp, mem) = temp_workspace(); - mem.store("pref", "User likes Rust", MemoryCategory::Core) + mem.store("pref", "User likes Rust", MemoryCategory::Core, None) .await .unwrap(); let content = sync_fs::read_to_string(mem.core_path()).unwrap(); @@ -253,7 +263,7 @@ mod tests { #[tokio::test] async fn markdown_store_daily() { let (_tmp, mem) = temp_workspace(); - mem.store("note", "Finished tests", MemoryCategory::Daily) + mem.store("note", "Finished tests", MemoryCategory::Daily, None) .await .unwrap(); let path = mem.daily_path(); @@ -264,17 +274,17 @@ mod tests { #[tokio::test] async fn markdown_recall_keyword() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is slow", MemoryCategory::Core) + mem.store("b", "Python is slow", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "Rust and safety", MemoryCategory::Core) + mem.store("c", "Rust and safety", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); assert!(results .iter() @@ -284,18 +294,20 @@ mod tests { #[tokio::test] async fn markdown_recall_no_match() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is great", MemoryCategory::Core) + mem.store("a", "Rust is great", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn markdown_count() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "first", MemoryCategory::Core).await.unwrap(); - mem.store("b", "second", MemoryCategory::Core) + mem.store("a", "first", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "second", MemoryCategory::Core, None) .await .unwrap(); let count = mem.count().await.unwrap(); @@ -305,24 +317,24 @@ mod tests { #[tokio::test] async fn markdown_list_by_category() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "core fact", MemoryCategory::Core) + mem.store("a", "core fact", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "daily note", MemoryCategory::Daily) + mem.store("b", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert!(core.iter().all(|e| e.category == MemoryCategory::Core)); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily)); } #[tokio::test] async fn markdown_forget_is_noop() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "permanent", MemoryCategory::Core) + mem.store("a", "permanent", MemoryCategory::Core, None) .await .unwrap(); let removed = mem.forget("a").await.unwrap(); @@ -332,7 +344,7 @@ mod tests { #[tokio::test] async fn markdown_empty_recall() { let (_tmp, mem) = temp_workspace(); - let results = mem.recall("anything", 10).await.unwrap(); + let results = mem.recall("anything", 10, None).await.unwrap(); assert!(results.is_empty()); } diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 66912ca..45b7451 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -1,12 +1,25 @@ +pub mod backend; pub mod chunker; pub mod embeddings; pub mod hygiene; +pub mod lucid; pub mod markdown; +pub mod none; +pub mod response_cache; +pub mod snapshot; pub mod sqlite; pub mod traits; pub mod vector; +#[allow(unused_imports)] +pub use backend::{ + classify_memory_backend, default_memory_backend_key, memory_backend_profile, + selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile, +}; +pub use lucid::LucidMemory; pub use markdown::MarkdownMemory; +pub use none::NoneMemory; +pub use response_cache::ResponseCache; pub use sqlite::SqliteMemory; pub use traits::Memory; #[allow(unused_imports)] @@ -16,6 +29,32 @@ use crate::config::MemoryConfig; use std::path::Path; use std::sync::Arc; +fn create_memory_with_sqlite_builder( + backend_name: &str, + workspace_dir: &Path, + mut sqlite_builder: F, + unknown_context: &str, +) -> anyhow::Result> +where + F: FnMut() -> anyhow::Result, +{ + match classify_memory_backend(backend_name) { + MemoryBackendKind::Sqlite => Ok(Box::new(sqlite_builder()?)), + MemoryBackendKind::Lucid => { + let local = sqlite_builder()?; + Ok(Box::new(LucidMemory::new(workspace_dir, local))) + } + MemoryBackendKind::Markdown => Ok(Box::new(MarkdownMemory::new(workspace_dir))), + MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())), + MemoryBackendKind::Unknown => { + tracing::warn!( + "Unknown memory backend '{backend_name}'{unknown_context}, falling back to markdown" + ); + Ok(Box::new(MarkdownMemory::new(workspace_dir))) + } + } +} + /// Factory: create the right memory backend from config pub fn create_memory( config: &MemoryConfig, @@ -27,30 +66,107 @@ pub fn create_memory( tracing::warn!("memory hygiene skipped: {e}"); } - match config.backend.as_str() { - "sqlite" => { - let embedder: Arc = - Arc::from(embeddings::create_embedding_provider( - &config.embedding_provider, - api_key, - &config.embedding_model, - config.embedding_dimensions, - )); - - #[allow(clippy::cast_possible_truncation)] - let mem = SqliteMemory::with_embedder( - workspace_dir, - embedder, - config.vector_weight as f32, - config.keyword_weight as f32, - config.embedding_cache_size, - )?; - Ok(Box::new(mem)) + // If snapshot_on_hygiene is enabled, export core memories during hygiene. + if config.snapshot_enabled && config.snapshot_on_hygiene { + if let Err(e) = snapshot::export_snapshot(workspace_dir) { + tracing::warn!("memory snapshot skipped: {e}"); } - "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))), - other => { - tracing::warn!("Unknown memory backend '{other}', falling back to markdown"); - Ok(Box::new(MarkdownMemory::new(workspace_dir))) + } + + // Auto-hydration: if brain.db is missing but MEMORY_SNAPSHOT.md exists, + // restore the "soul" from the snapshot before creating the backend. + if config.auto_hydrate + && matches!( + classify_memory_backend(&config.backend), + MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid + ) + && snapshot::should_hydrate(workspace_dir) + { + tracing::info!("🧬 Cold boot detected — hydrating from MEMORY_SNAPSHOT.md"); + match snapshot::hydrate_from_snapshot(workspace_dir) { + Ok(count) => { + if count > 0 { + tracing::info!("🧬 Hydrated {count} core memories from snapshot"); + } + } + Err(e) => { + tracing::warn!("memory hydration failed: {e}"); + } + } + } + + fn build_sqlite_memory( + config: &MemoryConfig, + workspace_dir: &Path, + api_key: Option<&str>, + ) -> anyhow::Result { + let embedder: Arc = + Arc::from(embeddings::create_embedding_provider( + &config.embedding_provider, + api_key, + &config.embedding_model, + config.embedding_dimensions, + )); + + #[allow(clippy::cast_possible_truncation)] + let mem = SqliteMemory::with_embedder( + workspace_dir, + embedder, + config.vector_weight as f32, + config.keyword_weight as f32, + config.embedding_cache_size, + )?; + Ok(mem) + } + + create_memory_with_sqlite_builder( + &config.backend, + workspace_dir, + || build_sqlite_memory(config, workspace_dir, api_key), + "", + ) +} + +pub fn create_memory_for_migration( + backend: &str, + workspace_dir: &Path, +) -> anyhow::Result> { + if matches!(classify_memory_backend(backend), MemoryBackendKind::None) { + anyhow::bail!( + "memory backend 'none' disables persistence; choose sqlite, lucid, or markdown before migration" + ); + } + + create_memory_with_sqlite_builder( + backend, + workspace_dir, + || SqliteMemory::new(workspace_dir), + " during migration", + ) +} + +/// Factory: create an optional response cache from config. +pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Option { + if !config.response_cache_enabled { + return None; + } + + match ResponseCache::new( + workspace_dir, + config.response_cache_ttl_minutes, + config.response_cache_max_entries, + ) { + Ok(cache) => { + tracing::info!( + "💾 Response cache enabled (TTL: {}min, max: {} entries)", + config.response_cache_ttl_minutes, + config.response_cache_max_entries + ); + Some(cache) + } + Err(e) => { + tracing::warn!("Response cache disabled due to error: {e}"); + None } } } @@ -83,14 +199,25 @@ mod tests { } #[test] - fn factory_none_falls_back_to_markdown() { + fn factory_lucid() { + let tmp = TempDir::new().unwrap(); + let cfg = MemoryConfig { + backend: "lucid".into(), + ..MemoryConfig::default() + }; + let mem = create_memory(&cfg, tmp.path(), None).unwrap(); + assert_eq!(mem.name(), "lucid"); + } + + #[test] + fn factory_none_uses_noop_memory() { let tmp = TempDir::new().unwrap(); let cfg = MemoryConfig { backend: "none".into(), ..MemoryConfig::default() }; let mem = create_memory(&cfg, tmp.path(), None).unwrap(); - assert_eq!(mem.name(), "markdown"); + assert_eq!(mem.name(), "none"); } #[test] @@ -103,4 +230,20 @@ mod tests { let mem = create_memory(&cfg, tmp.path(), None).unwrap(); assert_eq!(mem.name(), "markdown"); } + + #[test] + fn migration_factory_lucid() { + let tmp = TempDir::new().unwrap(); + let mem = create_memory_for_migration("lucid", tmp.path()).unwrap(); + assert_eq!(mem.name(), "lucid"); + } + + #[test] + fn migration_factory_none_is_rejected() { + let tmp = TempDir::new().unwrap(); + let error = create_memory_for_migration("none", tmp.path()) + .err() + .expect("backend=none should be rejected for migration"); + assert!(error.to_string().contains("disables persistence")); + } } diff --git a/src/memory/none.rs b/src/memory/none.rs new file mode 100644 index 0000000..4ccd2f8 --- /dev/null +++ b/src/memory/none.rs @@ -0,0 +1,87 @@ +use super::traits::{Memory, MemoryCategory, MemoryEntry}; +use async_trait::async_trait; + +/// Explicit no-op memory backend. +/// +/// This backend is used when `memory.backend = "none"` to disable persistence +/// while keeping the runtime wiring stable. +#[derive(Debug, Default, Clone, Copy)] +pub struct NoneMemory; + +impl NoneMemory { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl Memory for NoneMemory { + fn name(&self) -> &str { + "none" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn none_memory_is_noop() { + let memory = NoneMemory::new(); + + memory + .store("k", "v", MemoryCategory::Core, None) + .await + .unwrap(); + + assert!(memory.get("k").await.unwrap().is_none()); + assert!(memory.recall("k", 10, None).await.unwrap().is_empty()); + assert!(memory.list(None, None).await.unwrap().is_empty()); + assert!(!memory.forget("k").await.unwrap()); + assert_eq!(memory.count().await.unwrap(), 0); + assert!(memory.health_check().await); + } +} diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs new file mode 100644 index 0000000..62fae6c --- /dev/null +++ b/src/memory/response_cache.rs @@ -0,0 +1,351 @@ +//! Response cache — avoid burning tokens on repeated prompts. +//! +//! Stores LLM responses in a separate SQLite table keyed by a SHA-256 hash of +//! `(model, system_prompt_hash, user_prompt)`. Entries expire after a +//! configurable TTL (default: 1 hour). The cache is optional and disabled by +//! default — users opt in via `[memory] response_cache_enabled = true`. + +use anyhow::Result; +use chrono::{Duration, Local}; +use parking_lot::Mutex; +use rusqlite::{params, Connection}; +use sha2::{Digest, Sha256}; +use std::path::{Path, PathBuf}; + +/// Response cache backed by a dedicated SQLite database. +/// +/// Lives alongside `brain.db` as `response_cache.db` so it can be +/// independently wiped without touching memories. +pub struct ResponseCache { + conn: Mutex, + #[allow(dead_code)] + db_path: PathBuf, + ttl_minutes: i64, + max_entries: usize, +} + +impl ResponseCache { + /// Open (or create) the response cache database. + pub fn new(workspace_dir: &Path, ttl_minutes: u32, max_entries: usize) -> Result { + let db_dir = workspace_dir.join("memory"); + std::fs::create_dir_all(&db_dir)?; + let db_path = db_dir.join("response_cache.db"); + + let conn = Connection::open(&db_path)?; + + conn.execute_batch( + "PRAGMA journal_mode = WAL; + PRAGMA synchronous = NORMAL; + PRAGMA temp_store = MEMORY;", + )?; + + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS response_cache ( + prompt_hash TEXT PRIMARY KEY, + model TEXT NOT NULL, + response TEXT NOT NULL, + token_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + accessed_at TEXT NOT NULL, + hit_count INTEGER NOT NULL DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_rc_accessed ON response_cache(accessed_at); + CREATE INDEX IF NOT EXISTS idx_rc_created ON response_cache(created_at);", + )?; + + Ok(Self { + conn: Mutex::new(conn), + db_path, + ttl_minutes: i64::from(ttl_minutes), + max_entries, + }) + } + + /// Build a deterministic cache key from model + system prompt + user prompt. + pub fn cache_key(model: &str, system_prompt: Option<&str>, user_prompt: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(model.as_bytes()); + hasher.update(b"|"); + if let Some(sys) = system_prompt { + hasher.update(sys.as_bytes()); + } + hasher.update(b"|"); + hasher.update(user_prompt.as_bytes()); + let hash = hasher.finalize(); + format!("{:064x}", hash) + } + + /// Look up a cached response. Returns `None` on miss or expired entry. + pub fn get(&self, key: &str) -> Result> { + let conn = self.conn.lock(); + + let now = Local::now(); + let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339(); + + let mut stmt = conn.prepare( + "SELECT response FROM response_cache + WHERE prompt_hash = ?1 AND created_at > ?2", + )?; + + let result: Option = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok(); + + if result.is_some() { + // Bump hit count and accessed_at + let now_str = now.to_rfc3339(); + conn.execute( + "UPDATE response_cache + SET accessed_at = ?1, hit_count = hit_count + 1 + WHERE prompt_hash = ?2", + params![now_str, key], + )?; + } + + Ok(result) + } + + /// Store a response in the cache. + pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> { + let conn = self.conn.lock(); + + let now = Local::now().to_rfc3339(); + + conn.execute( + "INSERT OR REPLACE INTO response_cache + (prompt_hash, model, response, token_count, created_at, accessed_at, hit_count) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)", + params![key, model, response, token_count, now, now], + )?; + + // Evict expired entries + let cutoff = (Local::now() - Duration::minutes(self.ttl_minutes)).to_rfc3339(); + conn.execute( + "DELETE FROM response_cache WHERE created_at <= ?1", + params![cutoff], + )?; + + // LRU eviction if over max_entries + #[allow(clippy::cast_possible_wrap)] + let max = self.max_entries as i64; + conn.execute( + "DELETE FROM response_cache WHERE prompt_hash IN ( + SELECT prompt_hash FROM response_cache + ORDER BY accessed_at ASC + LIMIT MAX(0, (SELECT COUNT(*) FROM response_cache) - ?1) + )", + params![max], + )?; + + Ok(()) + } + + /// Return cache statistics: (total_entries, total_hits, total_tokens_saved). + pub fn stats(&self) -> Result<(usize, u64, u64)> { + let conn = self.conn.lock(); + + let count: i64 = + conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?; + + let hits: i64 = conn.query_row( + "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache", + [], + |row| row.get(0), + )?; + + let tokens_saved: i64 = conn.query_row( + "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache", + [], + |row| row.get(0), + )?; + + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + Ok((count as usize, hits as u64, tokens_saved as u64)) + } + + /// Wipe the entire cache (useful for `zeroclaw cache clear`). + pub fn clear(&self) -> Result { + let conn = self.conn.lock(); + + let affected = conn.execute("DELETE FROM response_cache", [])?; + Ok(affected) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn temp_cache(ttl_minutes: u32) -> (TempDir, ResponseCache) { + let tmp = TempDir::new().unwrap(); + let cache = ResponseCache::new(tmp.path(), ttl_minutes, 1000).unwrap(); + (tmp, cache) + } + + #[test] + fn cache_key_deterministic() { + let k1 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello"); + let k2 = ResponseCache::cache_key("gpt-4", Some("sys"), "hello"); + assert_eq!(k1, k2); + assert_eq!(k1.len(), 64); // SHA-256 hex + } + + #[test] + fn cache_key_varies_by_model() { + let k1 = ResponseCache::cache_key("gpt-4", None, "hello"); + let k2 = ResponseCache::cache_key("claude-3", None, "hello"); + assert_ne!(k1, k2); + } + + #[test] + fn cache_key_varies_by_system_prompt() { + let k1 = ResponseCache::cache_key("gpt-4", Some("You are helpful"), "hello"); + let k2 = ResponseCache::cache_key("gpt-4", Some("You are rude"), "hello"); + assert_ne!(k1, k2); + } + + #[test] + fn cache_key_varies_by_prompt() { + let k1 = ResponseCache::cache_key("gpt-4", None, "hello"); + let k2 = ResponseCache::cache_key("gpt-4", None, "goodbye"); + assert_ne!(k1, k2); + } + + #[test] + fn put_and_get() { + let (_tmp, cache) = temp_cache(60); + let key = ResponseCache::cache_key("gpt-4", None, "What is Rust?"); + + cache + .put(&key, "gpt-4", "Rust is a systems programming language.", 25) + .unwrap(); + + let result = cache.get(&key).unwrap(); + assert_eq!( + result.as_deref(), + Some("Rust is a systems programming language.") + ); + } + + #[test] + fn miss_returns_none() { + let (_tmp, cache) = temp_cache(60); + let result = cache.get("nonexistent_key").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn expired_entry_returns_none() { + let (_tmp, cache) = temp_cache(0); // 0-minute TTL → everything is instantly expired + let key = ResponseCache::cache_key("gpt-4", None, "test"); + + cache.put(&key, "gpt-4", "response", 10).unwrap(); + + // The entry was created with created_at = now(), but TTL is 0 minutes, + // so cutoff = now() - 0 = now(). The entry's created_at is NOT > cutoff. + let result = cache.get(&key).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn hit_count_incremented() { + let (_tmp, cache) = temp_cache(60); + let key = ResponseCache::cache_key("gpt-4", None, "hello"); + + cache.put(&key, "gpt-4", "Hi!", 5).unwrap(); + + // 3 hits + for _ in 0..3 { + let _ = cache.get(&key).unwrap(); + } + + let (_, total_hits, _) = cache.stats().unwrap(); + assert_eq!(total_hits, 3); + } + + #[test] + fn tokens_saved_calculated() { + let (_tmp, cache) = temp_cache(60); + let key = ResponseCache::cache_key("gpt-4", None, "explain rust"); + + cache.put(&key, "gpt-4", "Rust is...", 100).unwrap(); + + // 5 cache hits × 100 tokens = 500 tokens saved + for _ in 0..5 { + let _ = cache.get(&key).unwrap(); + } + + let (_, _, tokens_saved) = cache.stats().unwrap(); + assert_eq!(tokens_saved, 500); + } + + #[test] + fn lru_eviction() { + let tmp = TempDir::new().unwrap(); + let cache = ResponseCache::new(tmp.path(), 60, 3).unwrap(); // max 3 entries + + for i in 0..5 { + let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}")); + cache + .put(&key, "gpt-4", &format!("response {i}"), 10) + .unwrap(); + } + + let (count, _, _) = cache.stats().unwrap(); + assert!(count <= 3, "Should have at most 3 entries after eviction"); + } + + #[test] + fn clear_wipes_all() { + let (_tmp, cache) = temp_cache(60); + + for i in 0..10 { + let key = ResponseCache::cache_key("gpt-4", None, &format!("prompt {i}")); + cache + .put(&key, "gpt-4", &format!("response {i}"), 10) + .unwrap(); + } + + let cleared = cache.clear().unwrap(); + assert_eq!(cleared, 10); + + let (count, _, _) = cache.stats().unwrap(); + assert_eq!(count, 0); + } + + #[test] + fn stats_empty_cache() { + let (_tmp, cache) = temp_cache(60); + let (count, hits, tokens) = cache.stats().unwrap(); + assert_eq!(count, 0); + assert_eq!(hits, 0); + assert_eq!(tokens, 0); + } + + #[test] + fn overwrite_same_key() { + let (_tmp, cache) = temp_cache(60); + let key = ResponseCache::cache_key("gpt-4", None, "question"); + + cache.put(&key, "gpt-4", "answer v1", 20).unwrap(); + cache.put(&key, "gpt-4", "answer v2", 25).unwrap(); + + let result = cache.get(&key).unwrap(); + assert_eq!(result.as_deref(), Some("answer v2")); + + let (count, _, _) = cache.stats().unwrap(); + assert_eq!(count, 1); + } + + #[test] + fn unicode_prompt_handling() { + let (_tmp, cache) = temp_cache(60); + let key = ResponseCache::cache_key("gpt-4", None, "日本語のテスト 🦀"); + + cache + .put(&key, "gpt-4", "はい、Rustは素晴らしい", 30) + .unwrap(); + + let result = cache.get(&key).unwrap(); + assert_eq!(result.as_deref(), Some("はい、Rustは素晴らしい")); + } +} diff --git a/src/memory/snapshot.rs b/src/memory/snapshot.rs new file mode 100644 index 0000000..54f766e --- /dev/null +++ b/src/memory/snapshot.rs @@ -0,0 +1,470 @@ +//! Memory snapshot — export/import core memories as human-readable Markdown. +//! +//! **Atomic Soul Export**: dumps `MemoryCategory::Core` from SQLite into +//! `MEMORY_SNAPSHOT.md` so the agent's "soul" is always Git-visible. +//! +//! **Auto-Hydration**: if `brain.db` is missing but `MEMORY_SNAPSHOT.md` exists, +//! re-indexes all entries back into a fresh SQLite database. + +use anyhow::Result; +use chrono::Local; +use rusqlite::{params, Connection}; +use std::fmt::Write; +use std::fs; +use std::path::{Path, PathBuf}; + +/// Filename for the snapshot (lives at workspace root for Git visibility). +pub const SNAPSHOT_FILENAME: &str = "MEMORY_SNAPSHOT.md"; + +/// Header written at the top of every snapshot file. +const SNAPSHOT_HEADER: &str = "# 🧠 ZeroClaw Memory Snapshot\n\n\ + > Auto-generated by ZeroClaw. Do not edit manually unless you know what you're doing.\n\ + > This file is the \"soul\" of your agent — if `brain.db` is lost, start the agent\n\ + > in this workspace and it will auto-hydrate from this file.\n\n"; + +/// Export all `Core` memories from SQLite → `MEMORY_SNAPSHOT.md`. +/// +/// Returns the number of entries exported. +pub fn export_snapshot(workspace_dir: &Path) -> Result { + let db_path = workspace_dir.join("memory").join("brain.db"); + if !db_path.exists() { + tracing::debug!("snapshot export skipped: brain.db does not exist"); + return Ok(0); + } + + let conn = Connection::open(&db_path)?; + conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?; + + let mut stmt = conn.prepare( + "SELECT key, content, category, created_at, updated_at + FROM memories + WHERE category = 'core' + ORDER BY updated_at DESC", + )?; + + let rows: Vec<(String, String, String, String, String)> = stmt + .query_map([], |row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get(2)?, + row.get(3)?, + row.get(4)?, + )) + })? + .filter_map(|r| r.ok()) + .collect(); + + if rows.is_empty() { + tracing::debug!("snapshot export: no core memories to export"); + return Ok(0); + } + + let mut output = String::with_capacity(rows.len() * 200); + output.push_str(SNAPSHOT_HEADER); + + let now = Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); + write!(output, "**Last exported:** {now}\n\n").unwrap(); + write!(output, "**Total core memories:** {}\n\n---\n\n", rows.len()).unwrap(); + + for (key, content, _category, created_at, updated_at) in &rows { + write!(output, "### 🔑 `{key}`\n\n").unwrap(); + write!(output, "{content}\n\n").unwrap(); + write!( + output, + "*Created: {created_at} | Updated: {updated_at}*\n\n---\n\n" + ) + .unwrap(); + } + + let snapshot_path = snapshot_path(workspace_dir); + fs::write(&snapshot_path, output)?; + + tracing::info!( + "📸 Memory snapshot exported: {} core memories → {}", + rows.len(), + snapshot_path.display() + ); + + Ok(rows.len()) +} + +/// Import memories from `MEMORY_SNAPSHOT.md` into SQLite. +/// +/// Called during cold-boot when `brain.db` doesn't exist but the snapshot does. +/// Returns the number of entries hydrated. +pub fn hydrate_from_snapshot(workspace_dir: &Path) -> Result { + let snapshot = snapshot_path(workspace_dir); + if !snapshot.exists() { + return Ok(0); + } + + let content = fs::read_to_string(&snapshot)?; + let entries = parse_snapshot(&content); + + if entries.is_empty() { + return Ok(0); + } + + // Ensure the memory directory exists + let db_dir = workspace_dir.join("memory"); + fs::create_dir_all(&db_dir)?; + + let db_path = db_dir.join("brain.db"); + let conn = Connection::open(&db_path)?; + conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?; + + // Initialize schema (same as SqliteMemory::init_schema) + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + content TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'core', + embedding BLOB, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mem_key ON memories(key); + CREATE INDEX IF NOT EXISTS idx_mem_cat ON memories(category); + CREATE INDEX IF NOT EXISTS idx_mem_updated ON memories(updated_at); + + CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts + USING fts5(key, content, content='memories', content_rowid='rowid'); + + CREATE TABLE IF NOT EXISTS embedding_cache ( + content_hash TEXT PRIMARY KEY, + embedding BLOB NOT NULL, + created_at TEXT NOT NULL + );", + )?; + + let now = Local::now().to_rfc3339(); + let mut hydrated = 0; + + for (key, content) in &entries { + let id = uuid::Uuid::new_v4().to_string(); + let result = conn.execute( + "INSERT OR IGNORE INTO memories (id, key, content, category, created_at, updated_at) + VALUES (?1, ?2, ?3, 'core', ?4, ?5)", + params![id, key, content, now, now], + ); + + match result { + Ok(changed) if changed > 0 => { + // Populate FTS5 + let _ = conn.execute( + "INSERT INTO memories_fts(key, content) VALUES (?1, ?2)", + params![key, content], + ); + hydrated += 1; + } + Ok(_) => { + tracing::debug!("hydrate: key '{key}' already exists, skipping"); + } + Err(e) => { + tracing::warn!("hydrate: failed to insert key '{key}': {e}"); + } + } + } + + tracing::info!( + "🧬 Memory hydration complete: {} entries restored from {}", + hydrated, + snapshot.display() + ); + + Ok(hydrated) +} + +/// Check if we should auto-hydrate on startup. +/// +/// Returns `true` if: +/// 1. `brain.db` does NOT exist (or is empty) +/// 2. `MEMORY_SNAPSHOT.md` DOES exist +pub fn should_hydrate(workspace_dir: &Path) -> bool { + let db_path = workspace_dir.join("memory").join("brain.db"); + let snapshot = snapshot_path(workspace_dir); + + let db_missing_or_empty = if db_path.exists() { + // DB exists but might be empty (freshly created) + fs::metadata(&db_path) + .map(|m| m.len() < 4096) // SQLite header is ~4096 bytes minimum + .unwrap_or(true) + } else { + true + }; + + db_missing_or_empty && snapshot.exists() +} + +/// Path to the snapshot file. +fn snapshot_path(workspace_dir: &Path) -> PathBuf { + workspace_dir.join(SNAPSHOT_FILENAME) +} + +/// Parse the structured markdown snapshot back into (key, content) pairs. +fn parse_snapshot(input: &str) -> Vec<(String, String)> { + let mut entries = Vec::new(); + let mut current_key: Option = None; + let mut current_content = String::new(); + + for line in input.lines() { + let trimmed = line.trim(); + + // Match: ### 🔑 `key_name` + if trimmed.starts_with("### 🔑 `") && trimmed.ends_with('`') { + // Save previous entry + if let Some(key) = current_key.take() { + let content = current_content.trim().to_string(); + if !content.is_empty() { + entries.push((key, content)); + } + } + + // Extract new key + let key = trimmed + .strip_prefix("### 🔑 `") + .and_then(|s| s.strip_suffix('`')) + .unwrap_or("") + .to_string(); + + if !key.is_empty() { + current_key = Some(key); + current_content = String::new(); + } + } else if current_key.is_some() { + // Skip metadata lines and separators + if trimmed.starts_with("*Created:") || trimmed == "---" { + continue; + } + // Accumulate content + if !current_content.is_empty() || !trimmed.is_empty() { + if !current_content.is_empty() { + current_content.push('\n'); + } + current_content.push_str(line); + } + } + } + + // Don't forget the last entry + if let Some(key) = current_key { + let content = current_content.trim().to_string(); + if !content.is_empty() { + entries.push((key, content)); + } + } + + entries +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn parse_snapshot_basic() { + let input = r#"# 🧠 ZeroClaw Memory Snapshot + +> Auto-generated by ZeroClaw. + +**Last exported:** 2025-01-15 14:30:00 + +**Total core memories:** 2 + +--- + +### 🔑 `identity` + +I am ZeroClaw, a self-preserving AI agent. + +*Created: 2025-01-15 | Updated: 2025-01-15* + +--- + +### 🔑 `preference_lang` + +The user prefers Rust for systems programming. + +*Created: 2025-01-14 | Updated: 2025-01-15* + +--- +"#; + + let entries = parse_snapshot(input); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].0, "identity"); + assert!(entries[0].1.contains("self-preserving")); + assert_eq!(entries[1].0, "preference_lang"); + assert!(entries[1].1.contains("Rust")); + } + + #[test] + fn parse_snapshot_empty() { + let input = "# 🧠 ZeroClaw Memory Snapshot\n\n> Nothing here.\n"; + let entries = parse_snapshot(input); + assert!(entries.is_empty()); + } + + #[test] + fn parse_snapshot_multiline_content() { + let input = r#"### 🔑 `rules` + +Rule 1: Always be helpful. +Rule 2: Never lie. +Rule 3: Protect the user. + +*Created: 2025-01-15 | Updated: 2025-01-15* + +--- +"#; + + let entries = parse_snapshot(input); + assert_eq!(entries.len(), 1); + assert!(entries[0].1.contains("Rule 1")); + assert!(entries[0].1.contains("Rule 3")); + } + + #[test] + fn export_no_db_returns_zero() { + let tmp = TempDir::new().unwrap(); + let count = export_snapshot(tmp.path()).unwrap(); + assert_eq!(count, 0); + } + + #[test] + fn export_and_hydrate_roundtrip() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + + // Create a brain.db manually with some core memories + let db_dir = workspace.join("memory"); + fs::create_dir_all(&db_dir).unwrap(); + let db_path = db_dir.join("brain.db"); + + let conn = Connection::open(&db_path).unwrap(); + conn.execute_batch( + "PRAGMA journal_mode = WAL; + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + content TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'core', + embedding BLOB, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mem_key ON memories(key);", + ) + .unwrap(); + + let now = Local::now().to_rfc3339(); + conn.execute( + "INSERT INTO memories (id, key, content, category, created_at, updated_at) + VALUES ('id1', 'identity', 'I am a test agent', 'core', ?1, ?2)", + params![now, now], + ) + .unwrap(); + conn.execute( + "INSERT INTO memories (id, key, content, category, created_at, updated_at) + VALUES ('id2', 'preference', 'User likes Rust', 'core', ?1, ?2)", + params![now, now], + ) + .unwrap(); + // Non-core entry (should NOT be exported) + conn.execute( + "INSERT INTO memories (id, key, content, category, created_at, updated_at) + VALUES ('id3', 'conv1', 'Random convo', 'conversation', ?1, ?2)", + params![now, now], + ) + .unwrap(); + drop(conn); + + // Export snapshot + let exported = export_snapshot(workspace).unwrap(); + assert_eq!(exported, 2, "Should export only core memories"); + + // Verify the file exists and is readable + let snapshot = workspace.join(SNAPSHOT_FILENAME); + assert!(snapshot.exists()); + let content = fs::read_to_string(&snapshot).unwrap(); + assert!(content.contains("identity")); + assert!(content.contains("I am a test agent")); + assert!(content.contains("preference")); + assert!(!content.contains("Random convo")); + + // Simulate catastrophic failure: delete brain.db + fs::remove_file(&db_path).unwrap(); + assert!(!db_path.exists()); + + // Verify should_hydrate detects the scenario + assert!(should_hydrate(workspace)); + + // Hydrate from snapshot + let hydrated = hydrate_from_snapshot(workspace).unwrap(); + assert_eq!(hydrated, 2, "Should hydrate both core memories"); + + // Verify brain.db was recreated + assert!(db_path.exists()); + + // Verify the data is actually in the new database + let conn = Connection::open(&db_path).unwrap(); + let count: i64 = conn + .query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0)) + .unwrap(); + assert_eq!(count, 2); + + let identity: String = conn + .query_row( + "SELECT content FROM memories WHERE key = 'identity'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(identity, "I am a test agent"); + } + + #[test] + fn should_hydrate_only_when_needed() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + + // No DB, no snapshot → false + assert!(!should_hydrate(workspace)); + + // Create snapshot but no DB → true + let snapshot = workspace.join(SNAPSHOT_FILENAME); + fs::write(&snapshot, "### 🔑 `test`\n\nHello\n").unwrap(); + assert!(should_hydrate(workspace)); + + // Create a real DB → false + let db_dir = workspace.join("memory"); + fs::create_dir_all(&db_dir).unwrap(); + let db_path = db_dir.join("brain.db"); + let conn = Connection::open(&db_path).unwrap(); + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + content TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'core', + embedding BLOB, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + INSERT INTO memories VALUES('x','x','x','core',NULL,'2025-01-01','2025-01-01');", + ) + .unwrap(); + drop(conn); + assert!(!should_hydrate(workspace)); + } + + #[test] + fn hydrate_no_snapshot_returns_zero() { + let tmp = TempDir::new().unwrap(); + let count = hydrate_from_snapshot(tmp.path()).unwrap(); + assert_eq!(count, 0); + } +} diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index 93e6914..b0addeb 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -3,9 +3,10 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry}; use super::vector; use async_trait::async_trait; use chrono::Local; +use parking_lot::Mutex; use rusqlite::{params, Connection}; use std::path::{Path, PathBuf}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use uuid::Uuid; /// SQLite-backed persistent memory — the brain @@ -50,6 +51,21 @@ impl SqliteMemory { } let conn = Connection::open(&db_path)?; + + // ── Production-grade PRAGMA tuning ────────────────────── + // WAL mode: concurrent reads during writes, crash-safe + // normal sync: 2× write speed, still durable on WAL + // mmap 8 MB: let the OS page-cache serve hot reads + // cache 2 MB: keep ~500 hot pages in-process + // temp_store memory: temp tables never hit disk + conn.execute_batch( + "PRAGMA journal_mode = WAL; + PRAGMA synchronous = NORMAL; + PRAGMA mmap_size = 8388608; + PRAGMA cache_size = -2000; + PRAGMA temp_store = MEMORY;", + )?; + Self::init_schema(&conn)?; Ok(Self { @@ -108,6 +124,19 @@ impl SqliteMemory { ); CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);", )?; + + // Migration: add session_id column if not present (safe to run repeatedly) + let has_session_id: bool = conn + .prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")? + .query_row([], |row| row.get::<_, String>(0))? + .contains("session_id"); + if !has_session_id { + conn.execute_batch( + "ALTER TABLE memories ADD COLUMN session_id TEXT; + CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);", + )?; + } + Ok(()) } @@ -129,13 +158,21 @@ impl SqliteMemory { } } - /// Simple content hash for embedding cache + /// Deterministic content hash for embedding cache. + /// Uses SHA-256 (truncated) instead of DefaultHasher, which is + /// explicitly documented as unstable across Rust versions. fn content_hash(text: &str) -> String { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - text.hash(&mut hasher); - format!("{:016x}", hasher.finish()) + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(text.as_bytes()); + // First 8 bytes → 16 hex chars, matching previous format length + format!( + "{:016x}", + u64::from_be_bytes( + hash[..8] + .try_into() + .expect("SHA-256 always produces >= 8 bytes") + ) + ) } /// Get embedding from cache, or compute + cache it @@ -149,10 +186,7 @@ impl SqliteMemory { // Check cache { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?; @@ -174,10 +208,7 @@ impl SqliteMemory { // Store in cache + LRU eviction { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute( "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at) @@ -279,10 +310,7 @@ impl SqliteMemory { pub async fn reindex(&self) -> anyhow::Result { // Step 1: Rebuild FTS5 { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?; } @@ -293,10 +321,7 @@ impl SqliteMemory { } let entries: Vec<(String, String)> = { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?; @@ -310,10 +335,7 @@ impl SqliteMemory { for (id, content) in &entries { if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await { let bytes = vector::vec_to_bytes(&emb); - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute( "UPDATE memories SET embedding = ?1 WHERE id = ?2", params![bytes, id], @@ -337,6 +359,7 @@ impl Memory for SqliteMemory { key: &str, content: &str, category: MemoryCategory, + session_id: Option<&str>, ) -> anyhow::Result<()> { // Compute embedding (async, before lock) let embedding_bytes = self @@ -344,29 +367,32 @@ impl Memory for SqliteMemory { .await? .map(|emb| vector::vec_to_bytes(&emb)); - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let now = Local::now().to_rfc3339(); let cat = Self::category_to_str(&category); let id = Uuid::new_v4().to_string(); conn.execute( - "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) ON CONFLICT(key) DO UPDATE SET content = excluded.content, category = excluded.category, embedding = excluded.embedding, - updated_at = excluded.updated_at", - params![id, key, content, cat, embedding_bytes, now, now], + updated_at = excluded.updated_at, + session_id = excluded.session_id", + params![id, key, content, cat, embedding_bytes, now, now, session_id], )?; Ok(()) } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { if query.trim().is_empty() { return Ok(Vec::new()); } @@ -374,10 +400,7 @@ impl Memory for SqliteMemory { // Compute query embedding (async, before lock) let query_embedding = self.get_or_compute_embedding(query).await?; - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); // FTS5 BM25 keyword search let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default(); @@ -415,7 +438,7 @@ impl Memory for SqliteMemory { let mut results = Vec::new(); for scored in &merged { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE id = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", )?; if let Ok(entry) = stmt.query_row(params![scored.id], |row| { Ok(MemoryEntry { @@ -424,10 +447,16 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(f64::from(scored.final_score)), }) }) { + // Filter by session_id if requested + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } results.push(entry); } } @@ -446,7 +475,7 @@ impl Memory for SqliteMemory { .collect(); let where_clause = conditions.join(" OR "); let sql = format!( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE {where_clause} ORDER BY updated_at DESC LIMIT ?{}", @@ -469,12 +498,18 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(1.0), }) })?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } } @@ -484,13 +519,10 @@ impl Memory for SqliteMemory { } async fn get(&self, key: &str) -> anyhow::Result> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE key = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", )?; let mut rows = stmt.query_map(params![key], |row| { @@ -500,7 +532,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) })?; @@ -511,11 +543,12 @@ impl Memory for SqliteMemory { } } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { + let conn = self.conn.lock(); let mut results = Vec::new(); @@ -526,7 +559,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) }; @@ -534,21 +567,33 @@ impl Memory for SqliteMemory { if let Some(cat) = category { let cat_str = Self::category_to_str(cat); let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE category = ?1 ORDER BY updated_at DESC", )?; let rows = stmt.query_map(params![cat_str], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } else { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories ORDER BY updated_at DESC", )?; let rows = stmt.query_map([], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } @@ -556,29 +601,20 @@ impl Memory for SqliteMemory { } async fn forget(&self, key: &str) -> anyhow::Result { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?; Ok(affected > 0) } async fn count(&self) -> anyhow::Result { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?; #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] Ok(count as usize) } async fn health_check(&self) -> bool { - self.conn - .lock() - .map(|c| c.execute_batch("SELECT 1").is_ok()) - .unwrap_or(false) + self.conn.lock().execute_batch("SELECT 1").is_ok() } } @@ -608,7 +644,7 @@ mod tests { #[tokio::test] async fn sqlite_store_and_get() { let (_tmp, mem) = temp_sqlite(); - mem.store("user_lang", "Prefers Rust", MemoryCategory::Core) + mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -623,10 +659,10 @@ mod tests { #[tokio::test] async fn sqlite_store_upsert() { let (_tmp, mem) = temp_sqlite(); - mem.store("pref", "likes Rust", MemoryCategory::Core) + mem.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("pref", "loves Rust", MemoryCategory::Core) + mem.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -638,17 +674,22 @@ mod tests { #[tokio::test] async fn sqlite_recall_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast and safe", MemoryCategory::Core) + mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is interpreted", MemoryCategory::Core) - .await - .unwrap(); - mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core) + mem.store("b", "Python is interpreted", MemoryCategory::Core, None) .await .unwrap(); + mem.store( + "c", + "Rust has zero-cost abstractions", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert_eq!(results.len(), 2); assert!(results .iter() @@ -658,14 +699,14 @@ mod tests { #[tokio::test] async fn sqlite_recall_multi_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Rust is safe and fast", MemoryCategory::Core) + mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("fast safe", 10).await.unwrap(); + let results = mem.recall("fast safe", 10, None).await.unwrap(); assert!(!results.is_empty()); // Entry with both keywords should score higher assert!(results[0].content.contains("safe") && results[0].content.contains("fast")); @@ -674,17 +715,17 @@ mod tests { #[tokio::test] async fn sqlite_recall_no_match() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust rocks", MemoryCategory::Core) + mem.store("a", "Rust rocks", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn sqlite_forget() { let (_tmp, mem) = temp_sqlite(); - mem.store("temp", "temporary data", MemoryCategory::Conversation) + mem.store("temp", "temporary data", MemoryCategory::Conversation, None) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 1); @@ -704,29 +745,37 @@ mod tests { #[tokio::test] async fn sqlite_list_all() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "one", MemoryCategory::Core).await.unwrap(); - mem.store("b", "two", MemoryCategory::Daily).await.unwrap(); - mem.store("c", "three", MemoryCategory::Conversation) + mem.store("a", "one", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "two", MemoryCategory::Daily, None) + .await + .unwrap(); + mem.store("c", "three", MemoryCategory::Conversation, None) .await .unwrap(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert_eq!(all.len(), 3); } #[tokio::test] async fn sqlite_list_by_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "core1", MemoryCategory::Core).await.unwrap(); - mem.store("b", "core2", MemoryCategory::Core).await.unwrap(); - mem.store("c", "daily1", MemoryCategory::Daily) + mem.store("a", "core1", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "core2", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("c", "daily1", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert_eq!(core.len(), 2); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert_eq!(daily.len(), 1); } @@ -748,7 +797,7 @@ mod tests { { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("persist", "I survive restarts", MemoryCategory::Core) + mem.store("persist", "I survive restarts", MemoryCategory::Core, None) .await .unwrap(); } @@ -771,7 +820,7 @@ mod tests { ]; for (i, cat) in categories.iter().enumerate() { - mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone()) + mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None) .await .unwrap(); } @@ -791,21 +840,28 @@ mod tests { "a", "Rust is a systems programming language", MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store( + "b", + "Python is great for scripting", + MemoryCategory::Core, + None, ) .await .unwrap(); - mem.store("b", "Python is great for scripting", MemoryCategory::Core) - .await - .unwrap(); mem.store( "c", "Rust and Rust and Rust everywhere", MemoryCategory::Core, + None, ) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); // All results should contain "Rust" for r in &results { @@ -820,17 +876,17 @@ mod tests { #[tokio::test] async fn fts5_multi_word_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "The quick brown fox jumps", MemoryCategory::Core) + mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "A lazy dog sleeps", MemoryCategory::Core) + mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "The quick dog runs fast", MemoryCategory::Core) + mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("quick dog", 10).await.unwrap(); + let results = mem.recall("quick dog", 10, None).await.unwrap(); assert!(!results.is_empty()); // "The quick dog runs fast" matches both terms assert!(results[0].content.contains("quick")); @@ -839,16 +895,20 @@ mod tests { #[tokio::test] async fn recall_empty_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall("", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall("", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_whitespace_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall(" ", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall(" ", 10, None).await.unwrap(); assert!(results.is_empty()); } @@ -873,7 +933,7 @@ mod tests { #[tokio::test] async fn schema_has_fts5_table() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // FTS5 table should exist let count: i64 = conn .query_row( @@ -888,7 +948,7 @@ mod tests { #[tokio::test] async fn schema_has_embedding_cache() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'", @@ -902,7 +962,7 @@ mod tests { #[tokio::test] async fn schema_memories_has_embedding_column() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // Check that embedding column exists by querying it let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0"); assert!(result.is_ok()); @@ -913,11 +973,16 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_insert() { let (_tmp, mem) = temp_sqlite(); - mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "test_key", + "unique_searchterm_xyz", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'", @@ -931,12 +996,17 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_delete() { let (_tmp, mem) = temp_sqlite(); - mem.store("del_key", "deletable_content_abc", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "del_key", + "deletable_content_abc", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("del_key").await.unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'", @@ -950,14 +1020,19 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_update() { let (_tmp, mem) = temp_sqlite(); - mem.store("upd_key", "original_content_111", MemoryCategory::Core) - .await - .unwrap(); - mem.store("upd_key", "updated_content_222", MemoryCategory::Core) + mem.store( + "upd_key", + "original_content_111", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None) .await .unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // Old content should not be findable let old: i64 = conn .query_row( @@ -995,10 +1070,10 @@ mod tests { #[tokio::test] async fn reindex_rebuilds_fts() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex test alpha", MemoryCategory::Core) + mem.store("r1", "reindex test alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("r2", "reindex test beta", MemoryCategory::Core) + mem.store("r2", "reindex test beta", MemoryCategory::Core, None) .await .unwrap(); @@ -1007,7 +1082,7 @@ mod tests { assert_eq!(count, 0); // FTS should still work after rebuild - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 2); } @@ -1021,12 +1096,13 @@ mod tests { &format!("k{i}"), &format!("common keyword item {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); } - let results = mem.recall("common keyword", 5).await.unwrap(); + let results = mem.recall("common keyword", 5, None).await.unwrap(); assert!(results.len() <= 5); } @@ -1035,11 +1111,11 @@ mod tests { #[tokio::test] async fn recall_results_have_scores() { let (_tmp, mem) = temp_sqlite(); - mem.store("s1", "scored result test", MemoryCategory::Core) + mem.store("s1", "scored result test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("scored", 10).await.unwrap(); + let results = mem.recall("scored", 10, None).await.unwrap(); assert!(!results.is_empty()); for r in &results { assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); @@ -1051,11 +1127,11 @@ mod tests { #[tokio::test] async fn recall_with_quotes_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("q1", "He said hello world", MemoryCategory::Core) + mem.store("q1", "He said hello world", MemoryCategory::Core, None) .await .unwrap(); // Quotes in query should not crash FTS5 - let results = mem.recall("\"hello\"", 10).await.unwrap(); + let results = mem.recall("\"hello\"", 10, None).await.unwrap(); // May or may not match depending on FTS5 escaping, but must not error assert!(results.len() <= 10); } @@ -1063,31 +1139,34 @@ mod tests { #[tokio::test] async fn recall_with_asterisk_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a1", "wildcard test content", MemoryCategory::Core) + mem.store("a1", "wildcard test content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("wild*", 10).await.unwrap(); + let results = mem.recall("wild*", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_parentheses_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("p1", "function call test", MemoryCategory::Core) + mem.store("p1", "function call test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("function()", 10).await.unwrap(); + let results = mem.recall("function()", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_sql_injection_attempt() { let (_tmp, mem) = temp_sqlite(); - mem.store("safe", "normal content", MemoryCategory::Core) + mem.store("safe", "normal content", MemoryCategory::Core, None) .await .unwrap(); // Should not crash or leak data - let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap(); + let results = mem + .recall("'; DROP TABLE memories; --", 10, None) + .await + .unwrap(); assert!(results.len() <= 10); // Table should still exist assert_eq!(mem.count().await.unwrap(), 1); @@ -1098,7 +1177,9 @@ mod tests { #[tokio::test] async fn store_empty_content() { let (_tmp, mem) = temp_sqlite(); - mem.store("empty", "", MemoryCategory::Core).await.unwrap(); + mem.store("empty", "", MemoryCategory::Core, None) + .await + .unwrap(); let entry = mem.get("empty").await.unwrap().unwrap(); assert_eq!(entry.content, ""); } @@ -1106,7 +1187,7 @@ mod tests { #[tokio::test] async fn store_empty_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("", "content for empty key", MemoryCategory::Core) + mem.store("", "content for empty key", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("").await.unwrap().unwrap(); @@ -1117,7 +1198,7 @@ mod tests { async fn store_very_long_content() { let (_tmp, mem) = temp_sqlite(); let long_content = "x".repeat(100_000); - mem.store("long", &long_content, MemoryCategory::Core) + mem.store("long", &long_content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("long").await.unwrap().unwrap(); @@ -1127,9 +1208,14 @@ mod tests { #[tokio::test] async fn store_unicode_and_emoji() { let (_tmp, mem) = temp_sqlite(); - mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "emoji_key_🦀", + "こんにちは 🚀 Ñoño", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap(); assert_eq!(entry.content, "こんにちは 🚀 Ñoño"); } @@ -1138,7 +1224,7 @@ mod tests { async fn store_content_with_newlines_and_tabs() { let (_tmp, mem) = temp_sqlite(); let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph"; - mem.store("whitespace", content, MemoryCategory::Core) + mem.store("whitespace", content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("whitespace").await.unwrap().unwrap(); @@ -1150,11 +1236,11 @@ mod tests { #[tokio::test] async fn recall_single_character_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "x marks the spot", MemoryCategory::Core) + mem.store("a", "x marks the spot", MemoryCategory::Core, None) .await .unwrap(); // Single char may not match FTS5 but LIKE fallback should work - let results = mem.recall("x", 10).await.unwrap(); + let results = mem.recall("x", 10, None).await.unwrap(); // Should not crash; may or may not find results assert!(results.len() <= 10); } @@ -1162,23 +1248,23 @@ mod tests { #[tokio::test] async fn recall_limit_zero() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "some content", MemoryCategory::Core) + mem.store("a", "some content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("some", 0).await.unwrap(); + let results = mem.recall("some", 0, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_limit_one() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "matching content alpha", MemoryCategory::Core) + mem.store("a", "matching content alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "matching content beta", MemoryCategory::Core) + mem.store("b", "matching content beta", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("matching content", 1).await.unwrap(); + let results = mem.recall("matching content", 1, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1189,21 +1275,22 @@ mod tests { "rust_preferences", "User likes systems programming", MemoryCategory::Core, + None, ) .await .unwrap(); // "rust" appears in key but not content — LIKE fallback checks key too - let results = mem.recall("rust", 10).await.unwrap(); + let results = mem.recall("rust", 10, None).await.unwrap(); assert!(!results.is_empty(), "Should match by key"); } #[tokio::test] async fn recall_unicode_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("jp", "日本語のテスト", MemoryCategory::Core) + mem.store("jp", "日本語のテスト", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("日本語", 10).await.unwrap(); + let results = mem.recall("日本語", 10, None).await.unwrap(); assert!(!results.is_empty()); } @@ -1214,7 +1301,9 @@ mod tests { let tmp = TempDir::new().unwrap(); { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("k1", "v1", MemoryCategory::Core).await.unwrap(); + mem.store("k1", "v1", MemoryCategory::Core, None) + .await + .unwrap(); } // Open again — init_schema runs again on existing DB let mem2 = SqliteMemory::new(tmp.path()).unwrap(); @@ -1222,7 +1311,9 @@ mod tests { assert!(entry.is_some()); assert_eq!(entry.unwrap().content, "v1"); // Store more data — should work fine - mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap(); + mem2.store("k2", "v2", MemoryCategory::Daily, None) + .await + .unwrap(); assert_eq!(mem2.count().await.unwrap(), 2); } @@ -1240,11 +1331,16 @@ mod tests { #[tokio::test] async fn forget_then_recall_no_ghost_results() { let (_tmp, mem) = temp_sqlite(); - mem.store("ghost", "phantom memory content", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "ghost", + "phantom memory content", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("ghost").await.unwrap(); - let results = mem.recall("phantom memory", 10).await.unwrap(); + let results = mem.recall("phantom memory", 10, None).await.unwrap(); assert!( results.is_empty(), "Deleted memory should not appear in recall" @@ -1254,11 +1350,11 @@ mod tests { #[tokio::test] async fn forget_and_re_store_same_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("cycle", "version 1", MemoryCategory::Core) + mem.store("cycle", "version 1", MemoryCategory::Core, None) .await .unwrap(); mem.forget("cycle").await.unwrap(); - mem.store("cycle", "version 2", MemoryCategory::Core) + mem.store("cycle", "version 2", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("cycle").await.unwrap().unwrap(); @@ -1278,14 +1374,14 @@ mod tests { #[tokio::test] async fn reindex_twice_is_safe() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex data", MemoryCategory::Core) + mem.store("r1", "reindex data", MemoryCategory::Core, None) .await .unwrap(); mem.reindex().await.unwrap(); let count = mem.reindex().await.unwrap(); assert_eq!(count, 0); // Noop embedder → nothing to re-embed // Data should still be intact - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1339,18 +1435,28 @@ mod tests { #[tokio::test] async fn list_custom_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("c1", "custom1", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c2", "custom2", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c3", "other", MemoryCategory::Core) + mem.store( + "c1", + "custom1", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store( + "c2", + "custom2", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store("c3", "other", MemoryCategory::Core, None) .await .unwrap(); let project = mem - .list(Some(&MemoryCategory::Custom("project".into()))) + .list(Some(&MemoryCategory::Custom("project".into())), None) .await .unwrap(); assert_eq!(project.len(), 2); @@ -1359,7 +1465,122 @@ mod tests { #[tokio::test] async fn list_empty_db() { let (_tmp, mem) = temp_sqlite(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert!(all.is_empty()); } + + // ── Session isolation ───────────────────────────────────────── + + #[tokio::test] + async fn store_and_recall_with_session_id() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "no session fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall with session-a filter returns only session-a entry + let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-a")); + } + + #[tokio::test] + async fn recall_no_session_filter_returns_all() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "gamma fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall without session filter returns all matching entries + let results = mem.recall("fact", 10, None).await.unwrap(); + assert_eq!(results.len(), 3); + } + + #[tokio::test] + async fn cross_session_recall_isolation() { + let (_tmp, mem) = temp_sqlite(); + mem.store( + "secret", + "session A secret data", + MemoryCategory::Core, + Some("sess-a"), + ) + .await + .unwrap(); + + // Session B cannot see session A data + let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap(); + assert!(results.is_empty()); + + // Session A can see its own data + let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn list_with_session_filter() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a")) + .await + .unwrap(); + mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k4", "none1", MemoryCategory::Core, None) + .await + .unwrap(); + + // List with session-a filter + let results = mem.list(None, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 2); + assert!(results + .iter() + .all(|e| e.session_id.as_deref() == Some("sess-a"))); + + // List with session-a + category filter + let results = mem + .list(Some(&MemoryCategory::Core), Some("sess-a")) + .await + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + } + + #[tokio::test] + async fn schema_migration_idempotent_on_reopen() { + let tmp = TempDir::new().unwrap(); + + // First open: creates schema + migration + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x")) + .await + .unwrap(); + } + + // Second open: migration runs again but is idempotent + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-x")); + } + } } diff --git a/src/memory/traits.rs b/src/memory/traits.rs index 16d8fa6..bf8c021 100644 --- a/src/memory/traits.rs +++ b/src/memory/traits.rs @@ -44,18 +44,32 @@ pub trait Memory: Send + Sync { /// Backend name fn name(&self) -> &str; - /// Store a memory entry - async fn store(&self, key: &str, content: &str, category: MemoryCategory) - -> anyhow::Result<()>; + /// Store a memory entry, optionally scoped to a session + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()>; - /// Recall memories matching a query (keyword search) - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result>; + /// Recall memories matching a query (keyword search), optionally scoped to a session + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Get a specific memory by key async fn get(&self, key: &str) -> anyhow::Result>; - /// List all memory keys, optionally filtered by category - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result>; + /// List all memory keys, optionally filtered by category and/or session + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Remove a memory by key async fn forget(&self, key: &str) -> anyhow::Result; @@ -66,3 +80,53 @@ pub trait Memory: Send + Sync { /// Health check async fn health_check(&self) -> bool; } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn memory_category_display_outputs_expected_values() { + assert_eq!(MemoryCategory::Core.to_string(), "core"); + assert_eq!(MemoryCategory::Daily.to_string(), "daily"); + assert_eq!(MemoryCategory::Conversation.to_string(), "conversation"); + assert_eq!( + MemoryCategory::Custom("project_notes".into()).to_string(), + "project_notes" + ); + } + + #[test] + fn memory_category_serde_uses_snake_case() { + let core = serde_json::to_string(&MemoryCategory::Core).unwrap(); + let daily = serde_json::to_string(&MemoryCategory::Daily).unwrap(); + let conversation = serde_json::to_string(&MemoryCategory::Conversation).unwrap(); + + assert_eq!(core, "\"core\""); + assert_eq!(daily, "\"daily\""); + assert_eq!(conversation, "\"conversation\""); + } + + #[test] + fn memory_entry_roundtrip_preserves_optional_fields() { + let entry = MemoryEntry { + id: "id-1".into(), + key: "favorite_language".into(), + content: "Rust".into(), + category: MemoryCategory::Core, + timestamp: "2026-02-16T00:00:00Z".into(), + session_id: Some("session-abc".into()), + score: Some(0.98), + }; + + let json = serde_json::to_string(&entry).unwrap(); + let parsed: MemoryEntry = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.id, "id-1"); + assert_eq!(parsed.key, "favorite_language"); + assert_eq!(parsed.content, "Rust"); + assert_eq!(parsed.category, MemoryCategory::Core); + assert_eq!(parsed.session_id.as_deref(), Some("session-abc")); + assert_eq!(parsed.score, Some(0.98)); + } +} diff --git a/src/migration.rs b/src/migration.rs index 2ce29ba..8a83262 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -1,5 +1,5 @@ use crate::config::Config; -use crate::memory::{MarkdownMemory, Memory, MemoryCategory, SqliteMemory}; +use crate::memory::{self, Memory, MemoryCategory}; use anyhow::{bail, Context, Result}; use directories::UserDirs; use rusqlite::{Connection, OpenFlags, OptionalExtension}; @@ -23,9 +23,9 @@ struct MigrationStats { renamed_conflicts: usize, } -pub async fn handle_command(command: super::MigrateCommands, config: &Config) -> Result<()> { +pub async fn handle_command(command: crate::MigrateCommands, config: &Config) -> Result<()> { match command { - super::MigrateCommands::Openclaw { source, dry_run } => { + crate::MigrateCommands::Openclaw { source, dry_run } => { migrate_openclaw_memory(config, source, dry_run).await } } @@ -95,7 +95,9 @@ async fn migrate_openclaw_memory( stats.renamed_conflicts += 1; } - memory.store(&key, &entry.content, entry.category).await?; + memory + .store(&key, &entry.content, entry.category, None) + .await?; stats.imported += 1; } @@ -112,16 +114,7 @@ async fn migrate_openclaw_memory( } fn target_memory_backend(config: &Config) -> Result> { - match config.memory.backend.as_str() { - "sqlite" => Ok(Box::new(SqliteMemory::new(&config.workspace_dir)?)), - "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))), - other => { - tracing::warn!( - "Unknown memory backend '{other}' during migration, defaulting to markdown" - ); - Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))) - } - } + memory::create_memory_for_migration(&config.memory.backend, &config.workspace_dir) } fn collect_source_entries( @@ -431,6 +424,7 @@ fn backup_target_memory(workspace_dir: &Path) -> Result> { mod tests { use super::*; use crate::config::{Config, MemoryConfig}; + use crate::memory::SqliteMemory; use rusqlite::params; use tempfile::TempDir; @@ -496,7 +490,7 @@ mod tests { // Existing target memory let target_mem = SqliteMemory::new(target.path()).unwrap(); target_mem - .store("k", "new value", MemoryCategory::Core) + .store("k", "new value", MemoryCategory::Core, None) .await .unwrap(); @@ -518,7 +512,7 @@ mod tests { .await .unwrap(); - let all = target_mem.list(None).await.unwrap(); + let all = target_mem.list(None, None).await.unwrap(); assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); assert!(all .iter() @@ -550,4 +544,16 @@ mod tests { let target_mem = SqliteMemory::new(target.path()).unwrap(); assert_eq!(target_mem.count().await.unwrap(), 0); } + + #[test] + fn migration_target_rejects_none_backend() { + let target = TempDir::new().unwrap(); + let mut config = test_config(target.path()); + config.memory.backend = "none".to_string(); + + let err = target_memory_backend(&config) + .err() + .expect("backend=none should be rejected for migration target"); + assert!(err.to_string().contains("disables persistence")); + } } diff --git a/src/observability/log.rs b/src/observability/log.rs index eed4136..b932fe0 100644 --- a/src/observability/log.rs +++ b/src/observability/log.rs @@ -16,12 +16,45 @@ impl Observer for LogObserver { ObserverEvent::AgentStart { provider, model } => { info!(provider = %provider, model = %model, "agent.start"); } + ObserverEvent::LlmRequest { + provider, + model, + messages_count, + } => { + info!( + provider = %provider, + model = %model, + messages_count = messages_count, + "llm.request" + ); + } + ObserverEvent::LlmResponse { + provider, + model, + duration, + success, + error_message, + } => { + let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); + info!( + provider = %provider, + model = %model, + duration_ms = ms, + success = success, + error = ?error_message, + "llm.response" + ); + } ObserverEvent::AgentEnd { duration, tokens_used, + cost_usd, } => { let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); - info!(duration_ms = ms, tokens = ?tokens_used, "agent.end"); + info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end"); + } + ObserverEvent::ToolCallStart { tool } => { + info!(tool = %tool, "tool.start"); } ObserverEvent::ToolCall { tool, @@ -31,6 +64,9 @@ impl Observer for LogObserver { let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); info!(tool = %tool, duration_ms = ms, success = success, "tool.call"); } + ObserverEvent::TurnComplete => { + info!("turn.complete"); + } ObserverEvent::ChannelMessage { channel, direction } => { info!(channel = %channel, direction = %direction, "channel.message"); } @@ -83,19 +119,37 @@ mod tests { provider: "openrouter".into(), model: "claude-sonnet".into(), }); + obs.record_event(&ObserverEvent::LlmRequest { + provider: "openrouter".into(), + model: "claude-sonnet".into(), + messages_count: 2, + }); + obs.record_event(&ObserverEvent::LlmResponse { + provider: "openrouter".into(), + model: "claude-sonnet".into(), + duration: Duration::from_millis(250), + success: true, + error_message: None, + }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::from_millis(500), tokens_used: Some(100), + cost_usd: Some(0.0015), }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::ZERO, tokens_used: None, + cost_usd: None, + }); + obs.record_event(&ObserverEvent::ToolCallStart { + tool: "shell".into(), }); obs.record_event(&ObserverEvent::ToolCall { tool: "shell".into(), duration: Duration::from_millis(10), success: false, }); + obs.record_event(&ObserverEvent::TurnComplete); obs.record_event(&ObserverEvent::ChannelMessage { channel: "telegram".into(), direction: "outbound".into(), diff --git a/src/observability/mod.rs b/src/observability/mod.rs index 801771d..d4d75c7 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -1,11 +1,19 @@ pub mod log; pub mod multi; pub mod noop; +pub mod otel; pub mod traits; +pub mod verbose; +#[allow(unused_imports)] pub use self::log::LogObserver; +#[allow(unused_imports)] +pub use self::multi::MultiObserver; pub use noop::NoopObserver; +pub use otel::OtelObserver; pub use traits::{Observer, ObserverEvent}; +#[allow(unused_imports)] +pub use verbose::VerboseObserver; use crate::config::ObservabilityConfig; @@ -13,6 +21,27 @@ use crate::config::ObservabilityConfig; pub fn create_observer(config: &ObservabilityConfig) -> Box { match config.backend.as_str() { "log" => Box::new(LogObserver::new()), + "otel" | "opentelemetry" | "otlp" => { + match OtelObserver::new( + config.otel_endpoint.as_deref(), + config.otel_service_name.as_deref(), + ) { + Ok(obs) => { + tracing::info!( + endpoint = config + .otel_endpoint + .as_deref() + .unwrap_or("http://localhost:4318"), + "OpenTelemetry observer initialized" + ); + Box::new(obs) + } + Err(e) => { + tracing::error!("Failed to create OTel observer: {e}. Falling back to noop."); + Box::new(NoopObserver) + } + } + } "none" | "noop" => Box::new(NoopObserver), _ => { tracing::warn!( @@ -32,6 +61,7 @@ mod tests { fn factory_none_returns_noop() { let cfg = ObservabilityConfig { backend: "none".into(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "noop"); } @@ -40,6 +70,7 @@ mod tests { fn factory_noop_returns_noop() { let cfg = ObservabilityConfig { backend: "noop".into(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "noop"); } @@ -48,14 +79,46 @@ mod tests { fn factory_log_returns_log() { let cfg = ObservabilityConfig { backend: "log".into(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "log"); } + #[test] + fn factory_otel_returns_otel() { + let cfg = ObservabilityConfig { + backend: "otel".into(), + otel_endpoint: Some("http://127.0.0.1:19999".into()), + otel_service_name: Some("test".into()), + }; + assert_eq!(create_observer(&cfg).name(), "otel"); + } + + #[test] + fn factory_opentelemetry_alias() { + let cfg = ObservabilityConfig { + backend: "opentelemetry".into(), + otel_endpoint: Some("http://127.0.0.1:19999".into()), + otel_service_name: Some("test".into()), + }; + assert_eq!(create_observer(&cfg).name(), "otel"); + } + + #[test] + fn factory_otlp_alias() { + let cfg = ObservabilityConfig { + backend: "otlp".into(), + otel_endpoint: Some("http://127.0.0.1:19999".into()), + otel_service_name: Some("test".into()), + }; + assert_eq!(create_observer(&cfg).name(), "otel"); + } + #[test] fn factory_unknown_falls_back_to_noop() { let cfg = ObservabilityConfig { - backend: "prometheus".into(), + backend: "xyzzy_unknown".into(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "noop"); } @@ -64,6 +127,7 @@ mod tests { fn factory_empty_string_falls_back_to_noop() { let cfg = ObservabilityConfig { backend: String::new(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "noop"); } @@ -72,6 +136,7 @@ mod tests { fn factory_garbage_falls_back_to_noop() { let cfg = ObservabilityConfig { backend: "xyzzy_garbage_123".into(), + ..ObservabilityConfig::default() }; assert_eq!(create_observer(&cfg).name(), "noop"); } diff --git a/src/observability/noop.rs b/src/observability/noop.rs index 31f3a34..004af21 100644 --- a/src/observability/noop.rs +++ b/src/observability/noop.rs @@ -33,19 +33,37 @@ mod tests { provider: "test".into(), model: "test".into(), }); + obs.record_event(&ObserverEvent::LlmRequest { + provider: "test".into(), + model: "test".into(), + messages_count: 2, + }); + obs.record_event(&ObserverEvent::LlmResponse { + provider: "test".into(), + model: "test".into(), + duration: Duration::from_millis(1), + success: true, + error_message: None, + }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::from_millis(100), tokens_used: Some(42), + cost_usd: Some(0.001), }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::ZERO, tokens_used: None, + cost_usd: None, + }); + obs.record_event(&ObserverEvent::ToolCallStart { + tool: "shell".into(), }); obs.record_event(&ObserverEvent::ToolCall { tool: "shell".into(), duration: Duration::from_secs(1), success: true, }); + obs.record_event(&ObserverEvent::TurnComplete); obs.record_event(&ObserverEvent::ChannelMessage { channel: "cli".into(), direction: "inbound".into(), diff --git a/src/observability/otel.rs b/src/observability/otel.rs new file mode 100644 index 0000000..ae4932d --- /dev/null +++ b/src/observability/otel.rs @@ -0,0 +1,449 @@ +use super::traits::{Observer, ObserverEvent, ObserverMetric}; +use opentelemetry::metrics::{Counter, Gauge, Histogram}; +use opentelemetry::trace::{Span, SpanKind, Status, Tracer}; +use opentelemetry::{global, KeyValue}; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::metrics::SdkMeterProvider; +use opentelemetry_sdk::trace::SdkTracerProvider; +use std::time::SystemTime; + +/// OpenTelemetry-backed observer — exports traces and metrics via OTLP. +pub struct OtelObserver { + tracer_provider: SdkTracerProvider, + meter_provider: SdkMeterProvider, + + // Metrics instruments + agent_starts: Counter, + agent_duration: Histogram, + llm_calls: Counter, + llm_duration: Histogram, + tool_calls: Counter, + tool_duration: Histogram, + channel_messages: Counter, + heartbeat_ticks: Counter, + errors: Counter, + request_latency: Histogram, + tokens_used: Counter, + active_sessions: Gauge, + queue_depth: Gauge, +} + +impl OtelObserver { + /// Create a new OTel observer exporting to the given OTLP endpoint. + /// + /// Uses HTTP/protobuf transport (port 4318 by default). + /// Falls back to `http://localhost:4318` if no endpoint is provided. + pub fn new(endpoint: Option<&str>, service_name: Option<&str>) -> Result { + let endpoint = endpoint.unwrap_or("http://localhost:4318"); + let service_name = service_name.unwrap_or("zeroclaw"); + + // ── Trace exporter ────────────────────────────────────── + let span_exporter = opentelemetry_otlp::SpanExporter::builder() + .with_http() + .with_endpoint(endpoint) + .build() + .map_err(|e| format!("Failed to create OTLP span exporter: {e}"))?; + + let tracer_provider = SdkTracerProvider::builder() + .with_batch_exporter(span_exporter) + .with_resource( + opentelemetry_sdk::Resource::builder() + .with_service_name(service_name.to_string()) + .build(), + ) + .build(); + + global::set_tracer_provider(tracer_provider.clone()); + + // ── Metric exporter ───────────────────────────────────── + let metric_exporter = opentelemetry_otlp::MetricExporter::builder() + .with_http() + .with_endpoint(endpoint) + .build() + .map_err(|e| format!("Failed to create OTLP metric exporter: {e}"))?; + + let metric_reader = + opentelemetry_sdk::metrics::PeriodicReader::builder(metric_exporter).build(); + + let meter_provider = opentelemetry_sdk::metrics::SdkMeterProvider::builder() + .with_reader(metric_reader) + .with_resource( + opentelemetry_sdk::Resource::builder() + .with_service_name(service_name.to_string()) + .build(), + ) + .build(); + + let meter_provider_clone = meter_provider.clone(); + global::set_meter_provider(meter_provider); + + // ── Create metric instruments ──────────────────────────── + let meter = global::meter("zeroclaw"); + + let agent_starts = meter + .u64_counter("zeroclaw.agent.starts") + .with_description("Total agent invocations") + .build(); + + let agent_duration = meter + .f64_histogram("zeroclaw.agent.duration") + .with_description("Agent invocation duration in seconds") + .with_unit("s") + .build(); + + let llm_calls = meter + .u64_counter("zeroclaw.llm.calls") + .with_description("Total LLM provider calls") + .build(); + + let llm_duration = meter + .f64_histogram("zeroclaw.llm.duration") + .with_description("LLM provider call duration in seconds") + .with_unit("s") + .build(); + + let tool_calls = meter + .u64_counter("zeroclaw.tool.calls") + .with_description("Total tool calls") + .build(); + + let tool_duration = meter + .f64_histogram("zeroclaw.tool.duration") + .with_description("Tool execution duration in seconds") + .with_unit("s") + .build(); + + let channel_messages = meter + .u64_counter("zeroclaw.channel.messages") + .with_description("Total channel messages") + .build(); + + let heartbeat_ticks = meter + .u64_counter("zeroclaw.heartbeat.ticks") + .with_description("Total heartbeat ticks") + .build(); + + let errors = meter + .u64_counter("zeroclaw.errors") + .with_description("Total errors by component") + .build(); + + let request_latency = meter + .f64_histogram("zeroclaw.request.latency") + .with_description("Request latency in seconds") + .with_unit("s") + .build(); + + let tokens_used = meter + .u64_counter("zeroclaw.tokens.used") + .with_description("Total tokens consumed (monotonic)") + .build(); + + let active_sessions = meter + .u64_gauge("zeroclaw.sessions.active") + .with_description("Current number of active sessions") + .build(); + + let queue_depth = meter + .u64_gauge("zeroclaw.queue.depth") + .with_description("Current message queue depth") + .build(); + + Ok(Self { + tracer_provider, + meter_provider: meter_provider_clone, + agent_starts, + agent_duration, + llm_calls, + llm_duration, + tool_calls, + tool_duration, + channel_messages, + heartbeat_ticks, + errors, + request_latency, + tokens_used, + active_sessions, + queue_depth, + }) + } +} + +impl Observer for OtelObserver { + fn record_event(&self, event: &ObserverEvent) { + let tracer = global::tracer("zeroclaw"); + + match event { + ObserverEvent::AgentStart { provider, model } => { + self.agent_starts.add( + 1, + &[ + KeyValue::new("provider", provider.clone()), + KeyValue::new("model", model.clone()), + ], + ); + } + ObserverEvent::LlmRequest { .. } + | ObserverEvent::ToolCallStart { .. } + | ObserverEvent::TurnComplete => {} + ObserverEvent::LlmResponse { + provider, + model, + duration, + success, + error_message: _, + } => { + let secs = duration.as_secs_f64(); + let attrs = [ + KeyValue::new("provider", provider.clone()), + KeyValue::new("model", model.clone()), + KeyValue::new("success", success.to_string()), + ]; + self.llm_calls.add(1, &attrs); + self.llm_duration.record(secs, &attrs); + + // Create a completed span for visibility in trace backends. + let start_time = SystemTime::now() + .checked_sub(*duration) + .unwrap_or(SystemTime::now()); + let mut span = tracer.build( + opentelemetry::trace::SpanBuilder::from_name("llm.call") + .with_kind(SpanKind::Internal) + .with_start_time(start_time) + .with_attributes(vec![ + KeyValue::new("provider", provider.clone()), + KeyValue::new("model", model.clone()), + KeyValue::new("success", *success), + KeyValue::new("duration_s", secs), + ]), + ); + if *success { + span.set_status(Status::Ok); + } else { + span.set_status(Status::error("")); + } + span.end(); + } + ObserverEvent::AgentEnd { + duration, + tokens_used, + cost_usd, + } => { + let secs = duration.as_secs_f64(); + let start_time = SystemTime::now() + .checked_sub(*duration) + .unwrap_or(SystemTime::now()); + + // Create a completed span with correct timing + let mut span = tracer.build( + opentelemetry::trace::SpanBuilder::from_name("agent.invocation") + .with_kind(SpanKind::Internal) + .with_start_time(start_time) + .with_attributes(vec![KeyValue::new("duration_s", secs)]), + ); + if let Some(t) = tokens_used { + span.set_attribute(KeyValue::new("tokens_used", *t as i64)); + } + if let Some(c) = cost_usd { + span.set_attribute(KeyValue::new("cost_usd", *c)); + } + span.end(); + + self.agent_duration.record(secs, &[]); + // Note: tokens are recorded via record_metric(TokensUsed) to avoid + // double-counting. AgentEnd only records duration. + } + ObserverEvent::ToolCall { + tool, + duration, + success, + } => { + let secs = duration.as_secs_f64(); + let start_time = SystemTime::now() + .checked_sub(*duration) + .unwrap_or(SystemTime::now()); + + let status = if *success { + Status::Ok + } else { + Status::error("") + }; + + let mut span = tracer.build( + opentelemetry::trace::SpanBuilder::from_name("tool.call") + .with_kind(SpanKind::Internal) + .with_start_time(start_time) + .with_attributes(vec![ + KeyValue::new("tool.name", tool.clone()), + KeyValue::new("tool.success", *success), + KeyValue::new("duration_s", secs), + ]), + ); + span.set_status(status); + span.end(); + + let attrs = [ + KeyValue::new("tool", tool.clone()), + KeyValue::new("success", success.to_string()), + ]; + self.tool_calls.add(1, &attrs); + self.tool_duration + .record(secs, &[KeyValue::new("tool", tool.clone())]); + } + ObserverEvent::ChannelMessage { channel, direction } => { + self.channel_messages.add( + 1, + &[ + KeyValue::new("channel", channel.clone()), + KeyValue::new("direction", direction.clone()), + ], + ); + } + ObserverEvent::HeartbeatTick => { + self.heartbeat_ticks.add(1, &[]); + } + ObserverEvent::Error { component, message } => { + // Create an error span for visibility in trace backends + let mut span = tracer.build( + opentelemetry::trace::SpanBuilder::from_name("error") + .with_kind(SpanKind::Internal) + .with_attributes(vec![ + KeyValue::new("component", component.clone()), + KeyValue::new("error.message", message.clone()), + ]), + ); + span.set_status(Status::error(message.clone())); + span.end(); + + self.errors + .add(1, &[KeyValue::new("component", component.clone())]); + } + } + } + + fn record_metric(&self, metric: &ObserverMetric) { + match metric { + ObserverMetric::RequestLatency(d) => { + self.request_latency.record(d.as_secs_f64(), &[]); + } + ObserverMetric::TokensUsed(t) => { + self.tokens_used.add(*t as u64, &[]); + } + ObserverMetric::ActiveSessions(s) => { + self.active_sessions.record(*s as u64, &[]); + } + ObserverMetric::QueueDepth(d) => { + self.queue_depth.record(*d as u64, &[]); + } + } + } + + fn flush(&self) { + if let Err(e) = self.tracer_provider.force_flush() { + tracing::warn!("OTel trace flush failed: {e}"); + } + if let Err(e) = self.meter_provider.force_flush() { + tracing::warn!("OTel metric flush failed: {e}"); + } + } + + fn name(&self) -> &str { + "otel" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + // Note: OtelObserver::new() requires an OTLP endpoint. + // In tests we verify the struct creation fails gracefully + // when no collector is available, and test the observer interface + // by constructing with a known-unreachable endpoint (spans/metrics + // are buffered and exported asynchronously, so recording never panics). + + fn test_observer() -> OtelObserver { + // Create with a dummy endpoint — exports will silently fail + // but the observer itself works fine for recording + OtelObserver::new(Some("http://127.0.0.1:19999"), Some("zeroclaw-test")) + .expect("observer creation should not fail with valid endpoint format") + } + + #[test] + fn otel_observer_name() { + let obs = test_observer(); + assert_eq!(obs.name(), "otel"); + } + + #[test] + fn records_all_events_without_panic() { + let obs = test_observer(); + obs.record_event(&ObserverEvent::AgentStart { + provider: "openrouter".into(), + model: "claude-sonnet".into(), + }); + obs.record_event(&ObserverEvent::LlmRequest { + provider: "openrouter".into(), + model: "claude-sonnet".into(), + messages_count: 2, + }); + obs.record_event(&ObserverEvent::LlmResponse { + provider: "openrouter".into(), + model: "claude-sonnet".into(), + duration: Duration::from_millis(250), + success: true, + error_message: None, + }); + obs.record_event(&ObserverEvent::AgentEnd { + duration: Duration::from_millis(500), + tokens_used: Some(100), + cost_usd: Some(0.0015), + }); + obs.record_event(&ObserverEvent::AgentEnd { + duration: Duration::ZERO, + tokens_used: None, + cost_usd: None, + }); + obs.record_event(&ObserverEvent::ToolCallStart { + tool: "shell".into(), + }); + obs.record_event(&ObserverEvent::ToolCall { + tool: "shell".into(), + duration: Duration::from_millis(10), + success: true, + }); + obs.record_event(&ObserverEvent::ToolCall { + tool: "file_read".into(), + duration: Duration::from_millis(5), + success: false, + }); + obs.record_event(&ObserverEvent::TurnComplete); + obs.record_event(&ObserverEvent::ChannelMessage { + channel: "telegram".into(), + direction: "inbound".into(), + }); + obs.record_event(&ObserverEvent::HeartbeatTick); + obs.record_event(&ObserverEvent::Error { + component: "provider".into(), + message: "timeout".into(), + }); + } + + #[test] + fn records_all_metrics_without_panic() { + let obs = test_observer(); + obs.record_metric(&ObserverMetric::RequestLatency(Duration::from_secs(2))); + obs.record_metric(&ObserverMetric::TokensUsed(500)); + obs.record_metric(&ObserverMetric::TokensUsed(0)); + obs.record_metric(&ObserverMetric::ActiveSessions(3)); + obs.record_metric(&ObserverMetric::QueueDepth(42)); + } + + #[test] + fn flush_does_not_panic() { + let obs = test_observer(); + obs.record_event(&ObserverEvent::HeartbeatTick); + obs.flush(); + } +} diff --git a/src/observability/traits.rs b/src/observability/traits.rs index 84472e2..d978304 100644 --- a/src/observability/traits.rs +++ b/src/observability/traits.rs @@ -7,15 +7,39 @@ pub enum ObserverEvent { provider: String, model: String, }, + /// A request is about to be sent to an LLM provider. + /// + /// This is emitted immediately before a provider call so observers can print + /// user-facing progress without leaking prompt contents. + LlmRequest { + provider: String, + model: String, + messages_count: usize, + }, + /// Result of a single LLM provider call. + LlmResponse { + provider: String, + model: String, + duration: Duration, + success: bool, + error_message: Option, + }, AgentEnd { duration: Duration, tokens_used: Option, + cost_usd: Option, + }, + /// A tool call is about to be executed. + ToolCallStart { + tool: String, }, ToolCall { tool: String, duration: Duration, success: bool, }, + /// The agent produced a final answer for the current user message. + TurnComplete, ChannelMessage { channel: String, direction: String, @@ -37,7 +61,7 @@ pub enum ObserverMetric { } /// Core observability trait — implement for any backend -pub trait Observer: Send + Sync { +pub trait Observer: Send + Sync + 'static { /// Record a discrete event fn record_event(&self, event: &ObserverEvent); @@ -49,4 +73,81 @@ pub trait Observer: Send + Sync { /// Human-readable name of this observer fn name(&self) -> &str; + + /// Downcast to `Any` for backend-specific operations + fn as_any(&self) -> &dyn std::any::Any + where + Self: Sized, + { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use parking_lot::Mutex; + use std::time::Duration; + + #[derive(Default)] + struct DummyObserver { + events: Mutex, + metrics: Mutex, + } + + impl Observer for DummyObserver { + fn record_event(&self, _event: &ObserverEvent) { + let mut guard = self.events.lock(); + *guard += 1; + } + + fn record_metric(&self, _metric: &ObserverMetric) { + let mut guard = self.metrics.lock(); + *guard += 1; + } + + fn name(&self) -> &str { + "dummy-observer" + } + } + + #[test] + fn observer_records_events_and_metrics() { + let observer = DummyObserver::default(); + + observer.record_event(&ObserverEvent::HeartbeatTick); + observer.record_event(&ObserverEvent::Error { + component: "test".into(), + message: "boom".into(), + }); + observer.record_metric(&ObserverMetric::TokensUsed(42)); + + assert_eq!(*observer.events.lock(), 2); + assert_eq!(*observer.metrics.lock(), 1); + } + + #[test] + fn observer_default_flush_and_as_any_work() { + let observer = DummyObserver::default(); + + observer.flush(); + assert_eq!(observer.name(), "dummy-observer"); + assert!(observer.as_any().downcast_ref::().is_some()); + } + + #[test] + fn observer_event_and_metric_are_cloneable() { + let event = ObserverEvent::ToolCall { + tool: "shell".into(), + duration: Duration::from_millis(10), + success: true, + }; + let metric = ObserverMetric::RequestLatency(Duration::from_millis(8)); + + let cloned_event = event.clone(); + let cloned_metric = metric.clone(); + + assert!(matches!(cloned_event, ObserverEvent::ToolCall { .. })); + assert!(matches!(cloned_metric, ObserverMetric::RequestLatency(_))); + } } diff --git a/src/observability/verbose.rs b/src/observability/verbose.rs new file mode 100644 index 0000000..364be1e --- /dev/null +++ b/src/observability/verbose.rs @@ -0,0 +1,96 @@ +use super::traits::{Observer, ObserverEvent, ObserverMetric}; + +/// Human-readable progress observer for interactive CLI sessions. +/// +/// This observer prints compact `>` / `<` progress lines without exposing +/// prompt contents. It is intended to be opt-in (e.g. `--verbose`). +pub struct VerboseObserver; + +impl VerboseObserver { + pub fn new() -> Self { + Self + } +} + +impl Observer for VerboseObserver { + fn record_event(&self, event: &ObserverEvent) { + match event { + ObserverEvent::LlmRequest { + provider, + model, + messages_count, + } => { + eprintln!("> Thinking"); + eprintln!( + "> Send (provider={}, model={}, messages={})", + provider, model, messages_count + ); + } + ObserverEvent::LlmResponse { + duration, success, .. + } => { + let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); + eprintln!("< Receive (success={success}, duration_ms={ms})"); + } + ObserverEvent::ToolCallStart { tool } => { + eprintln!("> Tool {tool}"); + } + ObserverEvent::ToolCall { + tool, + duration, + success, + } => { + let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); + eprintln!("< Tool {tool} (success={success}, duration_ms={ms})"); + } + ObserverEvent::TurnComplete => { + eprintln!("< Complete"); + } + _ => {} + } + } + + #[inline(always)] + fn record_metric(&self, _metric: &ObserverMetric) {} + + fn name(&self) -> &str { + "verbose" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn verbose_name() { + assert_eq!(VerboseObserver::new().name(), "verbose"); + } + + #[test] + fn verbose_events_do_not_panic() { + let obs = VerboseObserver::new(); + obs.record_event(&ObserverEvent::LlmRequest { + provider: "openrouter".into(), + model: "claude".into(), + messages_count: 3, + }); + obs.record_event(&ObserverEvent::LlmResponse { + provider: "openrouter".into(), + model: "claude".into(), + duration: Duration::from_millis(12), + success: true, + error_message: None, + }); + obs.record_event(&ObserverEvent::ToolCallStart { + tool: "shell".into(), + }); + obs.record_event(&ObserverEvent::ToolCall { + tool: "shell".into(), + duration: Duration::from_millis(2), + success: true, + }); + obs.record_event(&ObserverEvent::TurnComplete); + } +} diff --git a/src/onboard/mod.rs b/src/onboard/mod.rs index a18ce8a..5117897 100644 --- a/src/onboard/mod.rs +++ b/src/onboard/mod.rs @@ -1,3 +1,18 @@ pub mod wizard; -pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard}; +pub use wizard::{run_channels_repair_wizard, run_models_refresh, run_quick_setup, run_wizard}; + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_reexport_exists(_value: F) {} + + #[test] + fn wizard_functions_are_reexported() { + assert_reexport_exists(run_wizard); + assert_reexport_exists(run_channels_repair_wizard); + assert_reexport_exists(run_quick_setup); + assert_reexport_exists(run_models_refresh); + } +} diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 3d38b27..38847fa 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -1,14 +1,26 @@ -use crate::config::schema::WhatsAppConfig; +use crate::config::schema::{DingTalkConfig, IrcConfig, QQConfig, WhatsAppConfig}; use crate::config::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, WebhookConfig, }; -use anyhow::{Context, Result}; +use crate::hardware::{self, HardwareConfig}; +use crate::memory::{ + default_memory_backend_key, memory_backend_profile, selectable_memory_backends, +}; +use crate::providers::{ + canonical_china_provider_name, is_glm_alias, is_glm_cn_alias, is_minimax_alias, + is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_zai_alias, is_zai_cn_alias, +}; +use anyhow::{bail, Context, Result}; use console::style; use dialoguer::{Confirm, Input, Select}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::BTreeSet; use std::fs; use std::path::{Path, PathBuf}; +use std::time::Duration; // ── Project context collected during wizard ────────────────────── @@ -38,6 +50,12 @@ const BANNER: &str = r" ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ "; +const LIVE_MODEL_MAX_OPTIONS: usize = 120; +const MODEL_PREVIEW_LIMIT: usize = 20; +const MODEL_CACHE_FILE: &str = "models_cache.json"; +const MODEL_CACHE_TTL_SECS: u64 = 12 * 60 * 60; +const CUSTOM_MODEL_SENTINEL: &str = "__custom_model__"; + // ── Main wizard entry point ────────────────────────────────────── pub fn run_wizard() -> Result { @@ -55,28 +73,31 @@ pub fn run_wizard() -> Result { ); println!(); - print_step(1, 8, "Workspace Setup"); + print_step(1, 9, "Workspace Setup"); let (workspace_dir, config_path) = setup_workspace()?; - print_step(2, 8, "AI Provider & API Key"); - let (provider, api_key, model) = setup_provider()?; + print_step(2, 9, "AI Provider & API Key"); + let (provider, api_key, model, provider_api_url) = setup_provider(&workspace_dir)?; - print_step(3, 8, "Channels (How You Talk to ZeroClaw)"); + print_step(3, 9, "Channels (How You Talk to ZeroClaw)"); let channels_config = setup_channels()?; - print_step(4, 8, "Tunnel (Expose to Internet)"); + print_step(4, 9, "Tunnel (Expose to Internet)"); let tunnel_config = setup_tunnel()?; - print_step(5, 8, "Tool Mode & Security"); + print_step(5, 9, "Tool Mode & Security"); let (composio_config, secrets_config) = setup_tool_mode()?; - print_step(6, 8, "Memory Configuration"); + print_step(6, 9, "Hardware (Physical World)"); + let hardware_config = setup_hardware()?; + + print_step(7, 9, "Memory Configuration"); let memory_config = setup_memory()?; - print_step(7, 8, "Project Context (Personalize Your Agent)"); + print_step(8, 9, "Project Context (Personalize Your Agent)"); let project_ctx = setup_project_context()?; - print_step(8, 8, "Workspace Files"); + print_step(9, 9, "Workspace Files"); scaffold_workspace(&workspace_dir, &project_ctx)?; // ── Build config ── @@ -89,6 +110,7 @@ pub fn run_wizard() -> Result { } else { Some(api_key) }, + api_url: provider_api_url, default_provider: Some(provider), default_model: Some(model), default_temperature: 0.7, @@ -96,7 +118,11 @@ pub fn run_wizard() -> Result { autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + scheduler: crate::config::schema::SchedulerConfig::default(), + agent: crate::config::schema::AgentConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), + cron: crate::config::CronConfig::default(), channels_config, memory: memory_config, // User-selected memory backend tunnel: tunnel_config, @@ -104,7 +130,12 @@ pub fn run_wizard() -> Result { composio: composio_config, secrets: secrets_config, browser: BrowserConfig::default(), + http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), + peripherals: crate::config::PeripheralsConfig::default(), + agents: std::collections::HashMap::new(), + hardware: hardware_config, }; println!( @@ -120,6 +151,7 @@ pub fn run_wizard() -> Result { ); config.save()?; + persist_workspace_selection(&config.config_path)?; // ── Final summary ──────────────────────────────────────────── print_summary(&config); @@ -129,7 +161,10 @@ pub fn run_wizard() -> Result { || config.channels_config.discord.is_some() || config.channels_config.slack.is_some() || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some(); + || config.channels_config.matrix.is_some() + || config.channels_config.email.is_some() + || config.channels_config.dingtalk.is_some() + || config.channels_config.qq.is_some(); if has_channels && config.api_key.is_some() { let launch: bool = Confirm::new() @@ -172,6 +207,7 @@ pub fn run_channels_repair_wizard() -> Result { print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); config.channels_config = setup_channels()?; config.save()?; + persist_workspace_selection(&config.config_path)?; println!(); println!( @@ -184,7 +220,10 @@ pub fn run_channels_repair_wizard() -> Result { || config.channels_config.discord.is_some() || config.channels_config.slack.is_some() || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some(); + || config.channels_config.matrix.is_some() + || config.channels_config.email.is_some() + || config.channels_config.dingtalk.is_some() + || config.channels_config.qq.is_some(); if has_channels && config.api_key.is_some() { let launch: bool = Confirm::new() @@ -214,11 +253,47 @@ pub fn run_channels_repair_wizard() -> Result { // ── Quick setup (zero prompts) ─────────────────────────────────── /// Non-interactive setup: generates a sensible default config instantly. -/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite`. +/// Use `zeroclaw onboard` or `zeroclaw onboard --api-key sk-... --provider openrouter --memory sqlite|lucid`. /// Use `zeroclaw onboard --interactive` for the full wizard. +fn backend_key_from_choice(choice: usize) -> &'static str { + selectable_memory_backends() + .get(choice) + .map_or(default_memory_backend_key(), |backend| backend.key) +} + +fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig { + let profile = memory_backend_profile(backend); + + MemoryConfig { + backend: backend.to_string(), + auto_save: profile.auto_save_default, + hygiene_enabled: profile.uses_sqlite_hygiene, + archive_after_days: if profile.uses_sqlite_hygiene { 7 } else { 0 }, + purge_after_days: if profile.uses_sqlite_hygiene { 30 } else { 0 }, + conversation_retention_days: 30, + embedding_provider: "none".to_string(), + embedding_model: "text-embedding-3-small".to_string(), + embedding_dimensions: 1536, + vector_weight: 0.7, + keyword_weight: 0.3, + embedding_cache_size: if profile.uses_sqlite_hygiene { + 10000 + } else { + 0 + }, + chunk_max_tokens: 512, + response_cache_enabled: false, + response_cache_ttl_minutes: 60, + response_cache_max_entries: 5_000, + snapshot_enabled: false, + snapshot_on_hygiene: false, + auto_hydrate: true, + } +} + #[allow(clippy::too_many_lines)] pub fn run_quick_setup( - api_key: Option<&str>, + credential_override: Option<&str>, provider: Option<&str>, memory_backend: Option<&str>, ) -> Result { @@ -242,41 +317,18 @@ pub fn run_quick_setup( let provider_name = provider.unwrap_or("openrouter").to_string(); let model = default_model_for_provider(&provider_name); - let memory_backend_name = memory_backend.unwrap_or("sqlite").to_string(); + let memory_backend_name = memory_backend + .unwrap_or(default_memory_backend_key()) + .to_string(); // Create memory config based on backend choice - let memory_config = MemoryConfig { - backend: memory_backend_name.clone(), - auto_save: memory_backend_name != "none", - hygiene_enabled: memory_backend_name == "sqlite", - archive_after_days: if memory_backend_name == "sqlite" { - 7 - } else { - 0 - }, - purge_after_days: if memory_backend_name == "sqlite" { - 30 - } else { - 0 - }, - conversation_retention_days: 30, - embedding_provider: "none".to_string(), - embedding_model: "text-embedding-3-small".to_string(), - embedding_dimensions: 1536, - vector_weight: 0.7, - keyword_weight: 0.3, - embedding_cache_size: if memory_backend_name == "sqlite" { - 10000 - } else { - 0 - }, - chunk_max_tokens: 512, - }; + let memory_config = memory_config_defaults_for_backend(&memory_backend_name); let config = Config { workspace_dir: workspace_dir.clone(), config_path: config_path.clone(), - api_key: api_key.map(String::from), + api_key: credential_override.map(String::from), + api_url: None, default_provider: Some(provider_name.clone()), default_model: Some(model.clone()), default_temperature: 0.7, @@ -284,7 +336,11 @@ pub fn run_quick_setup( autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + scheduler: crate::config::schema::SchedulerConfig::default(), + agent: crate::config::schema::AgentConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), + cron: crate::config::CronConfig::default(), channels_config: ChannelsConfig::default(), memory: memory_config, tunnel: crate::config::TunnelConfig::default(), @@ -292,10 +348,16 @@ pub fn run_quick_setup( composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), + peripherals: crate::config::PeripheralsConfig::default(), + agents: std::collections::HashMap::new(), + hardware: crate::config::HardwareConfig::default(), }; config.save()?; + persist_workspace_selection(&config.config_path)?; // Scaffold minimal workspace files let default_ctx = ProjectContext { @@ -326,7 +388,7 @@ pub fn run_quick_setup( println!( " {} API Key: {}", style("✓").green().bold(), - if api_key.is_some() { + if credential_override.is_some() { style("set").green() } else { style("not set (use --api-key or edit config.toml)").yellow() @@ -375,7 +437,7 @@ pub fn run_quick_setup( ); println!(); println!(" {}", style("Next steps:").white().bold()); - if api_key.is_none() { + if credential_override.is_none() { println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\""); println!(" 2. Or edit: ~/.zeroclaw/config.toml"); println!(" 3. Chat: zeroclaw agent -m \"Hello!\""); @@ -390,16 +452,818 @@ pub fn run_quick_setup( Ok(config) } +fn canonical_provider_name(provider_name: &str) -> &str { + if let Some(canonical) = canonical_china_provider_name(provider_name) { + return canonical; + } + + match provider_name { + "grok" => "xai", + "together" => "together-ai", + "google" | "google-gemini" => "gemini", + _ => provider_name, + } +} + /// Pick a sensible default model for the given provider. +const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [ + ("MiniMax-M2.5", "MiniMax M2.5 (latest, recommended)"), + ("MiniMax-M2.5-highspeed", "MiniMax M2.5 High-Speed (faster)"), + ("MiniMax-M2.1", "MiniMax M2.1 (stable)"), + ("MiniMax-M2.1-highspeed", "MiniMax M2.1 High-Speed (faster)"), + ("MiniMax-M2", "MiniMax M2 (legacy)"), +]; + fn default_model_for_provider(provider: &str) -> String { - match provider { - "anthropic" => "claude-sonnet-4-20250514".into(), - "openai" => "gpt-4o".into(), + match canonical_provider_name(provider) { + "anthropic" => "claude-sonnet-4-5-20250929".into(), + "openai" => "gpt-5.2".into(), + "glm" | "zai" => "glm-5".into(), + "minimax" => "MiniMax-M2.5".into(), + "qwen" => "qwen-plus".into(), "ollama" => "llama3.2".into(), "groq" => "llama-3.3-70b-versatile".into(), "deepseek" => "deepseek-chat".into(), - "glm" | "zhipu" => "glm-4.7".into(), - _ => "anthropic/claude-sonnet-4-20250514".into(), + "gemini" => "gemini-2.5-pro".into(), + _ => "anthropic/claude-sonnet-4.5".into(), + } +} + +fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> { + match canonical_provider_name(provider_name) { + "openrouter" => vec![ + ( + "anthropic/claude-sonnet-4.5".to_string(), + "Claude Sonnet 4.5 (balanced, recommended)".to_string(), + ), + ( + "openai/gpt-5.2".to_string(), + "GPT-5.2 (latest flagship)".to_string(), + ), + ( + "openai/gpt-5-mini".to_string(), + "GPT-5 mini (fast, cost-efficient)".to_string(), + ), + ( + "google/gemini-3-pro-preview".to_string(), + "Gemini 3 Pro Preview (frontier reasoning)".to_string(), + ), + ( + "x-ai/grok-4.1-fast".to_string(), + "Grok 4.1 Fast (reasoning + speed)".to_string(), + ), + ( + "deepseek/deepseek-v3.2".to_string(), + "DeepSeek V3.2 (agentic + affordable)".to_string(), + ), + ( + "meta-llama/llama-4-maverick".to_string(), + "Llama 4 Maverick (open model)".to_string(), + ), + ], + "anthropic" => vec![ + ( + "claude-sonnet-4-5-20250929".to_string(), + "Claude Sonnet 4.5 (balanced, recommended)".to_string(), + ), + ( + "claude-opus-4-6".to_string(), + "Claude Opus 4.6 (best quality)".to_string(), + ), + ( + "claude-haiku-4-5-20251001".to_string(), + "Claude Haiku 4.5 (fastest, cheapest)".to_string(), + ), + ], + "openai" => vec![ + ( + "gpt-5.2".to_string(), + "GPT-5.2 (latest coding/agentic flagship)".to_string(), + ), + ( + "gpt-5-mini".to_string(), + "GPT-5 mini (faster, cheaper)".to_string(), + ), + ( + "gpt-5-nano".to_string(), + "GPT-5 nano (lowest latency/cost)".to_string(), + ), + ( + "gpt-5.2-codex".to_string(), + "GPT-5.2 Codex (agentic coding)".to_string(), + ), + ], + "venice" => vec![ + ( + "llama-3.3-70b".to_string(), + "Llama 3.3 70B (default, fast)".to_string(), + ), + ( + "claude-opus-45".to_string(), + "Claude Opus 4.5 via Venice (strongest)".to_string(), + ), + ( + "llama-3.1-405b".to_string(), + "Llama 3.1 405B (largest open source)".to_string(), + ), + ], + "groq" => vec![ + ( + "llama-3.3-70b-versatile".to_string(), + "Llama 3.3 70B (fast, recommended)".to_string(), + ), + ( + "openai/gpt-oss-120b".to_string(), + "GPT-OSS 120B (strong open-weight)".to_string(), + ), + ( + "openai/gpt-oss-20b".to_string(), + "GPT-OSS 20B (cost-efficient open-weight)".to_string(), + ), + ], + "mistral" => vec![ + ( + "mistral-large-latest".to_string(), + "Mistral Large (latest flagship)".to_string(), + ), + ( + "mistral-medium-latest".to_string(), + "Mistral Medium (balanced)".to_string(), + ), + ( + "codestral-latest".to_string(), + "Codestral (code-focused)".to_string(), + ), + ( + "devstral-latest".to_string(), + "Devstral (software engineering specialist)".to_string(), + ), + ], + "deepseek" => vec![ + ( + "deepseek-chat".to_string(), + "DeepSeek Chat (mapped to V3.2 non-thinking)".to_string(), + ), + ( + "deepseek-reasoner".to_string(), + "DeepSeek Reasoner (mapped to V3.2 thinking)".to_string(), + ), + ], + "xai" => vec![ + ( + "grok-4-1-fast-reasoning".to_string(), + "Grok 4.1 Fast Reasoning (recommended)".to_string(), + ), + ( + "grok-4-1-fast-non-reasoning".to_string(), + "Grok 4.1 Fast Non-Reasoning (low latency)".to_string(), + ), + ( + "grok-code-fast-1".to_string(), + "Grok Code Fast 1 (coding specialist)".to_string(), + ), + ("grok-4".to_string(), "Grok 4 (max quality)".to_string()), + ], + "perplexity" => vec![ + ( + "sonar-pro".to_string(), + "Sonar Pro (flagship web-grounded model)".to_string(), + ), + ( + "sonar-reasoning-pro".to_string(), + "Sonar Reasoning Pro (complex multi-step reasoning)".to_string(), + ), + ( + "sonar-deep-research".to_string(), + "Sonar Deep Research (long-form research)".to_string(), + ), + ("sonar".to_string(), "Sonar (search, fast)".to_string()), + ], + "fireworks" => vec![ + ( + "accounts/fireworks/models/llama-v3p3-70b-instruct".to_string(), + "Llama 3.3 70B".to_string(), + ), + ( + "accounts/fireworks/models/mixtral-8x22b-instruct".to_string(), + "Mixtral 8x22B".to_string(), + ), + ], + "together-ai" => vec![ + ( + "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(), + "Llama 3.3 70B Instruct Turbo (recommended)".to_string(), + ), + ( + "moonshotai/Kimi-K2.5".to_string(), + "Kimi K2.5 (reasoning + coding)".to_string(), + ), + ( + "deepseek-ai/DeepSeek-V3.1".to_string(), + "DeepSeek V3.1 (strong value)".to_string(), + ), + ], + "cohere" => vec![ + ( + "command-a-03-2025".to_string(), + "Command A (flagship enterprise model)".to_string(), + ), + ( + "command-a-reasoning-08-2025".to_string(), + "Command A Reasoning (agentic reasoning)".to_string(), + ), + ( + "command-r-08-2024".to_string(), + "Command R (stable fast baseline)".to_string(), + ), + ], + "moonshot" => vec![ + ( + "kimi-latest".to_string(), + "Kimi Latest (rolling latest assistant model)".to_string(), + ), + ( + "kimi-k2-0905-preview".to_string(), + "Kimi K2 0905 Preview (strong coding)".to_string(), + ), + ( + "kimi-thinking-preview".to_string(), + "Kimi Thinking Preview (deep reasoning)".to_string(), + ), + ], + "glm" | "zai" => vec![ + ( + "glm-4.7".to_string(), + "GLM-4.7 (latest flagship)".to_string(), + ), + ("glm-5".to_string(), "GLM-5 (high reasoning)".to_string()), + ( + "glm-4-plus".to_string(), + "GLM-4 Plus (stable baseline)".to_string(), + ), + ], + "minimax" => vec![ + ( + "MiniMax-M2.5".to_string(), + "MiniMax M2.5 (latest flagship)".to_string(), + ), + ( + "MiniMax-M2.1".to_string(), + "MiniMax M2.1 (strong coding/reasoning)".to_string(), + ), + ( + "MiniMax-M2.1-lightning".to_string(), + "MiniMax M2.1 Lightning (fast)".to_string(), + ), + ], + "qwen" => vec![ + ( + "qwen-max".to_string(), + "Qwen Max (highest quality)".to_string(), + ), + ( + "qwen-plus".to_string(), + "Qwen Plus (balanced default)".to_string(), + ), + ( + "qwen-turbo".to_string(), + "Qwen Turbo (fast and cost-efficient)".to_string(), + ), + ], + "ollama" => vec![ + ( + "llama3.2".to_string(), + "Llama 3.2 (recommended local)".to_string(), + ), + ("mistral".to_string(), "Mistral 7B".to_string()), + ("codellama".to_string(), "Code Llama".to_string()), + ("phi3".to_string(), "Phi-3 (small, fast)".to_string()), + ], + "gemini" => vec![ + ( + "gemini-3-pro-preview".to_string(), + "Gemini 3 Pro Preview (latest frontier reasoning)".to_string(), + ), + ( + "gemini-2.5-pro".to_string(), + "Gemini 2.5 Pro (stable reasoning)".to_string(), + ), + ( + "gemini-2.5-flash".to_string(), + "Gemini 2.5 Flash (best price/performance)".to_string(), + ), + ( + "gemini-2.5-flash-lite".to_string(), + "Gemini 2.5 Flash-Lite (lowest cost)".to_string(), + ), + ], + _ => vec![("default".to_string(), "Default model".to_string())], + } +} + +fn supports_live_model_fetch(provider_name: &str) -> bool { + matches!( + canonical_provider_name(provider_name), + "openrouter" + | "openai" + | "anthropic" + | "groq" + | "mistral" + | "deepseek" + | "xai" + | "together-ai" + | "gemini" + | "ollama" + ) +} + +fn build_model_fetch_client() -> Result { + reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(8)) + .connect_timeout(Duration::from_secs(4)) + .build() + .context("failed to build model-fetch HTTP client") +} + +fn normalize_model_ids(ids: Vec) -> Vec { + let mut unique = BTreeSet::new(); + for id in ids { + let trimmed = id.trim(); + if !trimmed.is_empty() { + unique.insert(trimmed.to_string()); + } + } + unique.into_iter().collect() +} + +fn parse_openai_compatible_model_ids(payload: &Value) -> Vec { + let mut models = Vec::new(); + + if let Some(data) = payload.get("data").and_then(Value::as_array) { + for model in data { + if let Some(id) = model.get("id").and_then(Value::as_str) { + models.push(id.to_string()); + } + } + } else if let Some(data) = payload.as_array() { + for model in data { + if let Some(id) = model.get("id").and_then(Value::as_str) { + models.push(id.to_string()); + } + } + } + + normalize_model_ids(models) +} + +fn parse_gemini_model_ids(payload: &Value) -> Vec { + let Some(models) = payload.get("models").and_then(Value::as_array) else { + return Vec::new(); + }; + + let mut ids = Vec::new(); + for model in models { + let supports_generate_content = model + .get("supportedGenerationMethods") + .and_then(Value::as_array) + .is_none_or(|methods| { + methods + .iter() + .any(|method| method.as_str() == Some("generateContent")) + }); + + if !supports_generate_content { + continue; + } + + if let Some(name) = model.get("name").and_then(Value::as_str) { + ids.push(name.trim_start_matches("models/").to_string()); + } + } + + normalize_model_ids(ids) +} + +fn parse_ollama_model_ids(payload: &Value) -> Vec { + let Some(models) = payload.get("models").and_then(Value::as_array) else { + return Vec::new(); + }; + + let mut ids = Vec::new(); + for model in models { + if let Some(name) = model.get("name").and_then(Value::as_str) { + ids.push(name.to_string()); + } + } + + normalize_model_ids(ids) +} + +fn fetch_openai_compatible_models(endpoint: &str, api_key: Option<&str>) -> Result> { + let Some(api_key) = api_key else { + return Ok(Vec::new()); + }; + + let client = build_model_fetch_client()?; + let payload: Value = client + .get(endpoint) + .bearer_auth(api_key) + .send() + .and_then(reqwest::blocking::Response::error_for_status) + .with_context(|| format!("model fetch failed: GET {endpoint}"))? + .json() + .context("failed to parse model list response")?; + + Ok(parse_openai_compatible_model_ids(&payload)) +} + +fn fetch_openrouter_models(api_key: Option<&str>) -> Result> { + let client = build_model_fetch_client()?; + let mut request = client.get("https://openrouter.ai/api/v1/models"); + if let Some(api_key) = api_key { + request = request.bearer_auth(api_key); + } + + let payload: Value = request + .send() + .and_then(reqwest::blocking::Response::error_for_status) + .context("model fetch failed: GET https://openrouter.ai/api/v1/models")? + .json() + .context("failed to parse OpenRouter model list response")?; + + Ok(parse_openai_compatible_model_ids(&payload)) +} + +fn fetch_anthropic_models(api_key: Option<&str>) -> Result> { + let Some(api_key) = api_key else { + return Ok(Vec::new()); + }; + + let client = build_model_fetch_client()?; + let mut request = client + .get("https://api.anthropic.com/v1/models") + .header("anthropic-version", "2023-06-01"); + + if api_key.starts_with("sk-ant-oat01-") { + request = request + .header("Authorization", format!("Bearer {api_key}")) + .header("anthropic-beta", "oauth-2025-04-20"); + } else { + request = request.header("x-api-key", api_key); + } + + let response = request + .send() + .context("model fetch failed: GET https://api.anthropic.com/v1/models")?; + + let status = response.status(); + if !status.is_success() { + let body = response.text().unwrap_or_default(); + bail!("Anthropic model list request failed (HTTP {status}): {body}"); + } + + let payload: Value = response + .json() + .context("failed to parse Anthropic model list response")?; + + Ok(parse_openai_compatible_model_ids(&payload)) +} + +fn fetch_gemini_models(api_key: Option<&str>) -> Result> { + let Some(api_key) = api_key else { + return Ok(Vec::new()); + }; + + let client = build_model_fetch_client()?; + let payload: Value = client + .get("https://generativelanguage.googleapis.com/v1beta/models") + .query(&[("key", api_key), ("pageSize", "200")]) + .send() + .and_then(reqwest::blocking::Response::error_for_status) + .context("model fetch failed: GET Gemini models")? + .json() + .context("failed to parse Gemini model list response")?; + + Ok(parse_gemini_model_ids(&payload)) +} + +fn fetch_ollama_models() -> Result> { + let client = build_model_fetch_client()?; + let payload: Value = client + .get("http://localhost:11434/api/tags") + .send() + .and_then(reqwest::blocking::Response::error_for_status) + .context("model fetch failed: GET http://localhost:11434/api/tags")? + .json() + .context("failed to parse Ollama model list response")?; + + Ok(parse_ollama_model_ids(&payload)) +} + +fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result> { + let provider_name = canonical_provider_name(provider_name); + let api_key = if api_key.trim().is_empty() { + std::env::var(provider_env_var(provider_name)) + .ok() + .or_else(|| { + // Anthropic also accepts OAuth setup-tokens via ANTHROPIC_OAUTH_TOKEN + if provider_name == "anthropic" { + std::env::var("ANTHROPIC_OAUTH_TOKEN").ok() + } else { + None + } + }) + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + } else { + Some(api_key.trim().to_string()) + }; + + let models = match provider_name { + "openrouter" => fetch_openrouter_models(api_key.as_deref())?, + "openai" => { + fetch_openai_compatible_models("https://api.openai.com/v1/models", api_key.as_deref())? + } + "groq" => fetch_openai_compatible_models( + "https://api.groq.com/openai/v1/models", + api_key.as_deref(), + )?, + "mistral" => { + fetch_openai_compatible_models("https://api.mistral.ai/v1/models", api_key.as_deref())? + } + "deepseek" => fetch_openai_compatible_models( + "https://api.deepseek.com/v1/models", + api_key.as_deref(), + )?, + "xai" => fetch_openai_compatible_models("https://api.x.ai/v1/models", api_key.as_deref())?, + "together-ai" => fetch_openai_compatible_models( + "https://api.together.xyz/v1/models", + api_key.as_deref(), + )?, + "anthropic" => fetch_anthropic_models(api_key.as_deref())?, + "gemini" => fetch_gemini_models(api_key.as_deref())?, + "ollama" => fetch_ollama_models()?, + _ => Vec::new(), + }; + + Ok(models) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ModelCacheEntry { + provider: String, + fetched_at_unix: u64, + models: Vec, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct ModelCacheState { + entries: Vec, +} + +#[derive(Debug, Clone)] +struct CachedModels { + models: Vec, + age_secs: u64, +} + +fn model_cache_path(workspace_dir: &Path) -> PathBuf { + workspace_dir.join("state").join(MODEL_CACHE_FILE) +} + +fn now_unix_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn load_model_cache_state(workspace_dir: &Path) -> Result { + let path = model_cache_path(workspace_dir); + if !path.exists() { + return Ok(ModelCacheState::default()); + } + + let raw = fs::read_to_string(&path) + .with_context(|| format!("failed to read model cache at {}", path.display()))?; + + match serde_json::from_str::(&raw) { + Ok(state) => Ok(state), + Err(_) => Ok(ModelCacheState::default()), + } +} + +fn save_model_cache_state(workspace_dir: &Path, state: &ModelCacheState) -> Result<()> { + let path = model_cache_path(workspace_dir); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!( + "failed to create model cache directory {}", + parent.display() + ) + })?; + } + + let json = serde_json::to_vec_pretty(state).context("failed to serialize model cache")?; + fs::write(&path, json) + .with_context(|| format!("failed to write model cache at {}", path.display()))?; + + Ok(()) +} + +fn cache_live_models_for_provider( + workspace_dir: &Path, + provider_name: &str, + models: &[String], +) -> Result<()> { + let normalized_models = normalize_model_ids(models.to_vec()); + if normalized_models.is_empty() { + return Ok(()); + } + + let mut state = load_model_cache_state(workspace_dir)?; + let now = now_unix_secs(); + + if let Some(entry) = state + .entries + .iter_mut() + .find(|entry| entry.provider == provider_name) + { + entry.fetched_at_unix = now; + entry.models = normalized_models; + } else { + state.entries.push(ModelCacheEntry { + provider: provider_name.to_string(), + fetched_at_unix: now, + models: normalized_models, + }); + } + + save_model_cache_state(workspace_dir, &state) +} + +fn load_cached_models_for_provider_internal( + workspace_dir: &Path, + provider_name: &str, + ttl_secs: Option, +) -> Result> { + let state = load_model_cache_state(workspace_dir)?; + let now = now_unix_secs(); + + let Some(entry) = state + .entries + .into_iter() + .find(|entry| entry.provider == provider_name) + else { + return Ok(None); + }; + + if entry.models.is_empty() { + return Ok(None); + } + + let age_secs = now.saturating_sub(entry.fetched_at_unix); + if ttl_secs.is_some_and(|ttl| age_secs > ttl) { + return Ok(None); + } + + Ok(Some(CachedModels { + models: entry.models, + age_secs, + })) +} + +fn load_cached_models_for_provider( + workspace_dir: &Path, + provider_name: &str, + ttl_secs: u64, +) -> Result> { + load_cached_models_for_provider_internal(workspace_dir, provider_name, Some(ttl_secs)) +} + +fn load_any_cached_models_for_provider( + workspace_dir: &Path, + provider_name: &str, +) -> Result> { + load_cached_models_for_provider_internal(workspace_dir, provider_name, None) +} + +fn humanize_age(age_secs: u64) -> String { + if age_secs < 60 { + format!("{age_secs}s") + } else if age_secs < 60 * 60 { + format!("{}m", age_secs / 60) + } else { + format!("{}h", age_secs / (60 * 60)) + } +} + +fn build_model_options(model_ids: Vec, source: &str) -> Vec<(String, String)> { + model_ids + .into_iter() + .map(|model_id| { + let label = format!("{model_id} ({source})"); + (model_id, label) + }) + .collect() +} + +fn print_model_preview(models: &[String]) { + for model in models.iter().take(MODEL_PREVIEW_LIMIT) { + println!(" {} {model}", style("-")); + } + + if models.len() > MODEL_PREVIEW_LIMIT { + println!( + " {} ... and {} more", + style("-"), + models.len() - MODEL_PREVIEW_LIMIT + ); + } +} + +pub fn run_models_refresh( + config: &Config, + provider_override: Option<&str>, + force: bool, +) -> Result<()> { + let provider_name = provider_override + .or(config.default_provider.as_deref()) + .unwrap_or("openrouter") + .trim() + .to_string(); + + if provider_name.is_empty() { + anyhow::bail!("Provider name cannot be empty"); + } + + if !supports_live_model_fetch(&provider_name) { + anyhow::bail!("Provider '{provider_name}' does not support live model discovery yet"); + } + + if !force { + if let Some(cached) = load_cached_models_for_provider( + &config.workspace_dir, + &provider_name, + MODEL_CACHE_TTL_SECS, + )? { + println!( + "Using cached model list for '{}' (updated {} ago):", + provider_name, + humanize_age(cached.age_secs) + ); + print_model_preview(&cached.models); + println!(); + println!( + "Tip: run `zeroclaw models refresh --force --provider {}` to fetch latest now.", + provider_name + ); + return Ok(()); + } + } + + let api_key = config.api_key.clone().unwrap_or_default(); + + match fetch_live_models_for_provider(&provider_name, &api_key) { + Ok(models) if !models.is_empty() => { + cache_live_models_for_provider(&config.workspace_dir, &provider_name, &models)?; + println!( + "Refreshed '{}' model cache with {} models.", + provider_name, + models.len() + ); + print_model_preview(&models); + Ok(()) + } + Ok(_) => { + if let Some(stale_cache) = + load_any_cached_models_for_provider(&config.workspace_dir, &provider_name)? + { + println!( + "Provider returned no models; using stale cache (updated {} ago):", + humanize_age(stale_cache.age_secs) + ); + print_model_preview(&stale_cache.models); + return Ok(()); + } + + anyhow::bail!("Provider '{}' returned an empty model list", provider_name) + } + Err(error) => { + if let Some(stale_cache) = + load_any_cached_models_for_provider(&config.workspace_dir, &provider_name)? + { + println!( + "Live refresh failed ({}). Falling back to stale cache (updated {} ago):", + error, + humanize_age(stale_cache.age_secs) + ); + print_model_preview(&stale_cache.models); + return Ok(()); + } + + Err(error) + .with_context(|| format!("failed to refresh models for provider '{provider_name}'")) + } } } @@ -419,6 +1283,18 @@ fn print_bullet(text: &str) { println!(" {} {}", style("›").cyan(), text); } +fn persist_workspace_selection(config_path: &Path) -> Result<()> { + let config_dir = config_path + .parent() + .context("Config path must have a parent directory")?; + crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| { + format!( + "Failed to persist active workspace selection for {}", + config_dir.display() + ) + }) +} + // ── Step 1: Workspace ──────────────────────────────────────────── fn setup_workspace() -> Result<(PathBuf, PathBuf)> { @@ -464,13 +1340,13 @@ fn setup_workspace() -> Result<(PathBuf, PathBuf)> { // ── Step 2: Provider & API Key ─────────────────────────────────── #[allow(clippy::too_many_lines)] -fn setup_provider() -> Result<(String, String, String)> { +fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Option)> { // ── Tier selection ── let tiers = vec![ - "⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI)", - "⚡ Fast inference (Groq, Fireworks, Together AI)", + "⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI, Gemini)", + "⚡ Fast inference (Groq, Fireworks, Together AI, NVIDIA NIM)", "🌐 Gateway / proxy (Vercel AI, Cloudflare AI, Amazon Bedrock)", - "🔬 Specialized (Moonshot/Kimi, GLM/Zhipu, MiniMax, Qianfan, Z.AI, Synthetic, OpenCode Zen, Cohere)", + "🔬 Specialized (Moonshot/Kimi, GLM/Zhipu, MiniMax, Qwen/DashScope, Qianfan, Z.AI, Synthetic, OpenCode Zen, Cohere)", "🏠 Local / private (Ollama — no API key needed)", "🔧 Custom — bring your own OpenAI-compatible API", ]; @@ -494,11 +1370,16 @@ fn setup_provider() -> Result<(String, String, String)> { ("mistral", "Mistral — Large & Codestral"), ("xai", "xAI — Grok 3 & 4"), ("perplexity", "Perplexity — search-augmented AI"), + ( + "gemini", + "Google Gemini — Gemini 2.0 Flash & Pro (supports CLI auth)", + ), ], 1 => vec![ ("groq", "Groq — ultra-fast LPU inference"), ("fireworks", "Fireworks AI — fast open-source inference"), - ("together", "Together AI — open-source model hosting"), + ("together-ai", "Together AI — open-source model hosting"), + ("nvidia", "NVIDIA NIM — DeepSeek, Llama, & more"), ], 2 => vec![ ("vercel", "Vercel AI Gateway"), @@ -506,11 +1387,24 @@ fn setup_provider() -> Result<(String, String, String)> { ("bedrock", "Amazon Bedrock — AWS managed models"), ], 3 => vec![ - ("moonshot", "Moonshot — Kimi & Kimi Coding"), - ("glm", "GLM — ChatGLM / Zhipu models"), - ("minimax", "MiniMax — MiniMax AI models"), - ("qianfan", "Qianfan — Baidu AI models"), - ("zai", "Z.AI — Z.AI inference"), + ("moonshot", "Moonshot — Kimi API (China endpoint)"), + ( + "moonshot-intl", + "Moonshot — Kimi API (international endpoint)", + ), + ("glm", "GLM — ChatGLM / Zhipu (international endpoint)"), + ("glm-cn", "GLM — ChatGLM / Zhipu (China endpoint)"), + ( + "minimax", + "MiniMax — international endpoint (api.minimax.io)", + ), + ("minimax-cn", "MiniMax — China endpoint (api.minimaxi.com)"), + ("qwen", "Qwen — DashScope China endpoint"), + ("qwen-intl", "Qwen — DashScope international endpoint"), + ("qwen-us", "Qwen — DashScope US endpoint"), + ("qianfan", "Qianfan — Baidu AI models (China endpoint)"), + ("zai", "Z.AI — global coding endpoint"), + ("zai-cn", "Z.AI — China coding endpoint (open.bigmodel.cn)"), ("synthetic", "Synthetic — Synthetic AI models"), ("opencode", "OpenCode Zen — code-focused AI"), ("cohere", "Cohere — Command R+ & embeddings"), @@ -559,7 +1453,7 @@ fn setup_provider() -> Result<(String, String, String)> { style(&model).green() ); - return Ok((provider_name, api_key, model)); + return Ok((provider_name, api_key, model, None)); } let provider_labels: Vec<&str> = providers.iter().map(|(_, label)| *label).collect(); @@ -572,30 +1466,168 @@ fn setup_provider() -> Result<(String, String, String)> { let provider_name = providers[provider_idx].0; - // ── API key ── + // ── API key / endpoint ── + let mut provider_api_url: Option = None; let api_key = if provider_name == "ollama" { - print_bullet("Ollama runs locally — no API key needed!"); - String::new() + let use_remote_ollama = Confirm::new() + .with_prompt(" Use a remote Ollama endpoint (for example Ollama Cloud)?") + .default(false) + .interact()?; + + if use_remote_ollama { + let raw_url: String = Input::new() + .with_prompt(" Remote Ollama endpoint URL") + .default("https://ollama.com".into()) + .interact_text()?; + + let normalized_url = raw_url.trim().trim_end_matches('/').to_string(); + if normalized_url.is_empty() { + anyhow::bail!("Remote Ollama endpoint URL cannot be empty."); + } + + provider_api_url = Some(normalized_url.clone()); + + print_bullet(&format!( + "Remote endpoint configured: {}", + style(&normalized_url).cyan() + )); + print_bullet(&format!( + "If you use cloud-only models, append {} to the model ID.", + style(":cloud").yellow() + )); + + let key: String = Input::new() + .with_prompt(" API key for remote Ollama endpoint (or Enter to skip)") + .allow_empty(true) + .interact_text()?; + + if key.trim().is_empty() { + print_bullet(&format!( + "No API key provided. Set {} later if required by your endpoint.", + style("OLLAMA_API_KEY").yellow() + )); + } + + key + } else { + print_bullet("Using local Ollama at http://localhost:11434 (no API key needed)."); + String::new() + } + } else if canonical_provider_name(provider_name) == "gemini" { + // Special handling for Gemini: check for CLI auth first + if crate::providers::gemini::GeminiProvider::has_cli_credentials() { + print_bullet(&format!( + "{} Gemini CLI credentials detected! You can skip the API key.", + style("✓").green().bold() + )); + print_bullet("ZeroClaw will reuse your existing Gemini CLI authentication."); + println!(); + + let use_cli: bool = dialoguer::Confirm::new() + .with_prompt(" Use existing Gemini CLI authentication?") + .default(true) + .interact()?; + + if use_cli { + println!( + " {} Using Gemini CLI OAuth tokens", + style("✓").green().bold() + ); + String::new() // Empty key = will use CLI tokens + } else { + print_bullet("Get your API key at: https://aistudio.google.com/app/apikey"); + Input::new() + .with_prompt(" Paste your Gemini API key") + .allow_empty(true) + .interact_text()? + } + } else if std::env::var("GEMINI_API_KEY").is_ok() { + print_bullet(&format!( + "{} GEMINI_API_KEY environment variable detected!", + style("✓").green().bold() + )); + String::new() + } else { + print_bullet("Get your API key at: https://aistudio.google.com/app/apikey"); + print_bullet("Or run `gemini` CLI to authenticate (tokens will be reused)."); + println!(); + + Input::new() + .with_prompt(" Paste your Gemini API key (or press Enter to skip)") + .allow_empty(true) + .interact_text()? + } + } else if canonical_provider_name(provider_name) == "anthropic" { + if std::env::var("ANTHROPIC_OAUTH_TOKEN").is_ok() { + print_bullet(&format!( + "{} ANTHROPIC_OAUTH_TOKEN environment variable detected!", + style("✓").green().bold() + )); + String::new() + } else if std::env::var("ANTHROPIC_API_KEY").is_ok() { + print_bullet(&format!( + "{} ANTHROPIC_API_KEY environment variable detected!", + style("✓").green().bold() + )); + String::new() + } else { + print_bullet(&format!( + "Get your API key at: {}", + style("https://console.anthropic.com/settings/keys") + .cyan() + .underlined() + )); + print_bullet("Or run `claude setup-token` to get an OAuth setup-token."); + println!(); + + let key: String = Input::new() + .with_prompt(" Paste your API key or setup-token (or press Enter to skip)") + .allow_empty(true) + .interact_text()?; + + if key.is_empty() { + print_bullet(&format!( + "Skipped. Set {} or {} or edit config.toml later.", + style("ANTHROPIC_API_KEY").yellow(), + style("ANTHROPIC_OAUTH_TOKEN").yellow() + )); + } + + key + } } else { - let key_url = match provider_name { - "openrouter" => "https://openrouter.ai/keys", - "anthropic" => "https://console.anthropic.com/settings/keys", - "openai" => "https://platform.openai.com/api-keys", - "venice" => "https://venice.ai/settings/api", - "groq" => "https://console.groq.com/keys", - "mistral" => "https://console.mistral.ai/api-keys", - "deepseek" => "https://platform.deepseek.com/api_keys", - "together" => "https://api.together.xyz/settings/api-keys", - "fireworks" => "https://fireworks.ai/account/api-keys", - "perplexity" => "https://www.perplexity.ai/settings/api", - "xai" => "https://console.x.ai", - "cohere" => "https://dashboard.cohere.com/api-keys", - "moonshot" => "https://platform.moonshot.cn/console/api-keys", - "minimax" => "https://www.minimaxi.com/user-center/basic-information", - "vercel" => "https://vercel.com/account/tokens", - "cloudflare" => "https://dash.cloudflare.com/profile/api-tokens", - "bedrock" => "https://console.aws.amazon.com/iam", - _ => "", + let key_url = if is_moonshot_alias(provider_name) { + "https://platform.moonshot.cn/console/api-keys" + } else if is_glm_cn_alias(provider_name) || is_zai_cn_alias(provider_name) { + "https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys" + } else if is_glm_alias(provider_name) || is_zai_alias(provider_name) { + "https://platform.z.ai/" + } else if is_minimax_alias(provider_name) { + "https://www.minimaxi.com/user-center/basic-information" + } else if is_qwen_alias(provider_name) { + "https://help.aliyun.com/zh/model-studio/developer-reference/get-api-key" + } else if is_qianfan_alias(provider_name) { + "https://cloud.baidu.com/doc/WENXINWORKSHOP/s/7lm0vxo78" + } else { + match provider_name { + "openrouter" => "https://openrouter.ai/keys", + "openai" => "https://platform.openai.com/api-keys", + "venice" => "https://venice.ai/settings/api", + "groq" => "https://console.groq.com/keys", + "mistral" => "https://console.mistral.ai/api-keys", + "deepseek" => "https://platform.deepseek.com/api_keys", + "together-ai" => "https://api.together.xyz/settings/api-keys", + "fireworks" => "https://fireworks.ai/account/api-keys", + "perplexity" => "https://www.perplexity.ai/settings/api", + "xai" => "https://console.x.ai", + "cohere" => "https://dashboard.cohere.com/api-keys", + "vercel" => "https://vercel.com/account/tokens", + "cloudflare" => "https://dash.cloudflare.com/profile/api-tokens", + "nvidia" | "nvidia-nim" | "build.nvidia.com" => "https://build.nvidia.com/", + "bedrock" => "https://console.aws.amazon.com/iam", + "gemini" => "https://aistudio.google.com/app/apikey", + _ => "", + } }; println!(); @@ -625,10 +1657,11 @@ fn setup_provider() -> Result<(String, String, String)> { }; // ── Model selection ── - let models: Vec<(&str, &str)> = match provider_name { + let canonical_provider = canonical_provider_name(provider_name); + let models: Vec<(&str, &str)> = match canonical_provider { "openrouter" => vec![ ( - "anthropic/claude-sonnet-4-20250514", + "anthropic/claude-sonnet-4", "Claude Sonnet 4 (balanced, recommended)", ), ( @@ -703,7 +1736,7 @@ fn setup_provider() -> Result<(String, String, String)> { "Mixtral 8x22B", ), ], - "together" => vec![ + "together-ai" => vec![ ( "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "Llama 3.1 70B Turbo", @@ -714,6 +1747,12 @@ fn setup_provider() -> Result<(String, String, String)> { ), ("mistralai/Mixtral-8x22B-Instruct-v0.1", "Mixtral 8x22B"), ], + "nvidia" | "nvidia-nim" | "build.nvidia.com" => vec![ + ("deepseek-ai/DeepSeek-R1", "DeepSeek R1 (reasoning)"), + ("meta/llama-3.1-70b-instruct", "Llama 3.1 70B Instruct"), + ("mistralai/Mistral-7B-Instruct-v0.3", "Mistral 7B Instruct"), + ("meta/llama-3.1-405b-instruct", "Llama 3.1 405B Instruct"), + ], "cohere" => vec![ ("command-r-plus", "Command R+ (flagship)"), ("command-r", "Command R (fast)"), @@ -722,15 +1761,16 @@ fn setup_provider() -> Result<(String, String, String)> { ("moonshot-v1-128k", "Moonshot V1 128K"), ("moonshot-v1-32k", "Moonshot V1 32K"), ], - "glm" => vec![ - ("glm-4.7", "GLM-4.7 (flagship, 358B, recommended)"), - ("glm-4.7-flash", "GLM-4.7 Flash (fast, free-tier)"), - ("glm-4-plus", "GLM-4 Plus (previous gen)"), - ("glm-4-flash", "GLM-4 Flash (previous gen, fast)"), + "glm" | "zai" => vec![ + ("glm-5", "GLM-5 (latest)"), + ("glm-4-plus", "GLM-4 Plus (flagship)"), + ("glm-4-flash", "GLM-4 Flash (fast)"), ], - "minimax" => vec![ - ("abab6.5s-chat", "ABAB 6.5s Chat"), - ("abab6.5-chat", "ABAB 6.5 Chat"), + "minimax" => MINIMAX_ONBOARD_MODELS.to_vec(), + "qwen" => vec![ + ("qwen-plus", "Qwen Plus (balanced default)"), + ("qwen-max", "Qwen Max (highest quality)"), + ("qwen-turbo", "Qwen Turbo (fast and cost-efficient)"), ], "ollama" => vec![ ("llama3.2", "Llama 3.2 (recommended local)"), @@ -738,10 +1778,160 @@ fn setup_provider() -> Result<(String, String, String)> { ("codellama", "Code Llama"), ("phi3", "Phi-3 (small, fast)"), ], + "gemini" | "google" | "google-gemini" => vec![ + ("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"), + ( + "gemini-2.0-flash-lite", + "Gemini 2.0 Flash Lite (fastest, cheapest)", + ), + ("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"), + ("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"), + ], _ => vec![("default", "Default model")], }; - let model_labels: Vec<&str> = models.iter().map(|(_, label)| *label).collect(); + let mut model_options: Vec<(String, String)> = models + .into_iter() + .map(|(model_id, label)| (model_id.to_string(), label.to_string())) + .collect(); + let mut live_options: Option> = None; + + if provider_name == "ollama" && provider_api_url.is_some() { + print_bullet( + "Skipping local Ollama model discovery because a remote endpoint is configured.", + ); + } else if supports_live_model_fetch(provider_name) { + let can_fetch_without_key = matches!(provider_name, "openrouter" | "ollama"); + let has_api_key = !api_key.trim().is_empty() + || std::env::var(provider_env_var(provider_name)) + .ok() + .is_some_and(|value| !value.trim().is_empty()); + + if can_fetch_without_key || has_api_key { + if let Some(cached) = + load_cached_models_for_provider(workspace_dir, provider_name, MODEL_CACHE_TTL_SECS)? + { + let shown_count = cached.models.len().min(LIVE_MODEL_MAX_OPTIONS); + print_bullet(&format!( + "Found cached models ({shown_count}) updated {} ago.", + humanize_age(cached.age_secs) + )); + + live_options = Some(build_model_options( + cached + .models + .into_iter() + .take(LIVE_MODEL_MAX_OPTIONS) + .collect(), + "cached", + )); + } + + let should_fetch_now = Confirm::new() + .with_prompt(if live_options.is_some() { + " Refresh models from provider now?" + } else { + " Fetch latest models from provider now?" + }) + .default(live_options.is_none()) + .interact()?; + + if should_fetch_now { + match fetch_live_models_for_provider(provider_name, &api_key) { + Ok(live_model_ids) if !live_model_ids.is_empty() => { + cache_live_models_for_provider( + workspace_dir, + provider_name, + &live_model_ids, + )?; + + let fetched_count = live_model_ids.len(); + let shown_count = fetched_count.min(LIVE_MODEL_MAX_OPTIONS); + let shown_models: Vec = live_model_ids + .into_iter() + .take(LIVE_MODEL_MAX_OPTIONS) + .collect(); + + if shown_count < fetched_count { + print_bullet(&format!( + "Fetched {fetched_count} models. Showing first {shown_count}." + )); + } else { + print_bullet(&format!("Fetched {shown_count} live models.")); + } + + live_options = Some(build_model_options(shown_models, "live")); + } + Ok(_) => { + print_bullet("Provider returned no models; using curated list."); + } + Err(error) => { + print_bullet(&format!( + "Live fetch failed ({}); using cached/curated list.", + style(error.to_string()).yellow() + )); + + if live_options.is_none() { + if let Some(stale) = + load_any_cached_models_for_provider(workspace_dir, provider_name)? + { + print_bullet(&format!( + "Loaded stale cache from {} ago.", + humanize_age(stale.age_secs) + )); + + live_options = Some(build_model_options( + stale + .models + .into_iter() + .take(LIVE_MODEL_MAX_OPTIONS) + .collect(), + "stale-cache", + )); + } + } + } + } + } + } else { + print_bullet("No API key detected, so using curated model list."); + print_bullet("Tip: add an API key and rerun onboarding to fetch live models."); + } + } + + if let Some(live_model_options) = live_options { + let source_options = vec![ + format!("Provider model list ({})", live_model_options.len()), + format!("Curated starter list ({})", model_options.len()), + ]; + + let source_idx = Select::new() + .with_prompt(" Model source") + .items(&source_options) + .default(0) + .interact()?; + + if source_idx == 0 { + model_options = live_model_options; + } + } + + if model_options.is_empty() { + model_options.push(( + default_model_for_provider(provider_name), + "Provider default model".to_string(), + )); + } + + model_options.push(( + CUSTOM_MODEL_SENTINEL.to_string(), + "Custom model ID (type manually)".to_string(), + )); + + let model_labels: Vec = model_options + .iter() + .map(|(model_id, label)| format!("{label} — {}", style(model_id).dim())) + .collect(); let model_idx = Select::new() .with_prompt(" Select your default model") @@ -749,7 +1939,15 @@ fn setup_provider() -> Result<(String, String, String)> { .default(0) .interact()?; - let model = models[model_idx].0.to_string(); + let selected_model = model_options[model_idx].0.clone(); + let model = if selected_model == CUSTOM_MODEL_SENTINEL { + Input::new() + .with_prompt(" Enter custom model ID") + .default(default_model_for_provider(provider_name)) + .interact_text()? + } else { + selected_model + }; println!( " {} Provider: {} | Model: {}", @@ -758,34 +1956,38 @@ fn setup_provider() -> Result<(String, String, String)> { style(&model).green() ); - Ok((provider_name.to_string(), api_key, model)) + Ok((provider_name.to_string(), api_key, model, provider_api_url)) } /// Map provider name to its conventional env var fn provider_env_var(name: &str) -> &'static str { - match name { + match canonical_provider_name(name) { "openrouter" => "OPENROUTER_API_KEY", "anthropic" => "ANTHROPIC_API_KEY", "openai" => "OPENAI_API_KEY", + "ollama" => "OLLAMA_API_KEY", "venice" => "VENICE_API_KEY", "groq" => "GROQ_API_KEY", "mistral" => "MISTRAL_API_KEY", "deepseek" => "DEEPSEEK_API_KEY", - "xai" | "grok" => "XAI_API_KEY", - "together" | "together-ai" => "TOGETHER_API_KEY", + "xai" => "XAI_API_KEY", + "together-ai" => "TOGETHER_API_KEY", "fireworks" | "fireworks-ai" => "FIREWORKS_API_KEY", "perplexity" => "PERPLEXITY_API_KEY", "cohere" => "COHERE_API_KEY", - "moonshot" | "kimi" => "MOONSHOT_API_KEY", - "glm" | "zhipu" => "GLM_API_KEY", + "moonshot" => "MOONSHOT_API_KEY", + "glm" => "GLM_API_KEY", "minimax" => "MINIMAX_API_KEY", - "qianfan" | "baidu" => "QIANFAN_API_KEY", - "zai" | "z.ai" => "ZAI_API_KEY", + "qwen" => "DASHSCOPE_API_KEY", + "qianfan" => "QIANFAN_API_KEY", + "zai" => "ZAI_API_KEY", "synthetic" => "SYNTHETIC_API_KEY", "opencode" | "opencode-zen" => "OPENCODE_API_KEY", "vercel" | "vercel-ai" => "VERCEL_API_KEY", "cloudflare" | "cloudflare-ai" => "CLOUDFLARE_API_KEY", "bedrock" | "aws-bedrock" => "AWS_ACCESS_KEY_ID", + "gemini" => "GEMINI_API_KEY", + "nvidia" | "nvidia-nim" | "build.nvidia.com" => "NVIDIA_API_KEY", _ => "API_KEY", } } @@ -880,6 +2082,194 @@ fn setup_tool_mode() -> Result<(ComposioConfig, SecretsConfig)> { Ok((composio_config, secrets_config)) } +// ── Step 6: Hardware (Physical World) ─────────────────────────── + +fn setup_hardware() -> Result { + print_bullet("ZeroClaw can talk to physical hardware (LEDs, sensors, motors)."); + print_bullet("Scanning for connected devices..."); + println!(); + + // ── Auto-discovery ── + let devices = hardware::discover_hardware(); + + if devices.is_empty() { + println!( + " {} {}", + style("ℹ").dim(), + style("No hardware devices detected on this system.").dim() + ); + println!( + " {} {}", + style("ℹ").dim(), + style("You can enable hardware later in config.toml under [hardware].").dim() + ); + } else { + println!( + " {} {} device(s) found:", + style("✓").green().bold(), + devices.len() + ); + for device in &devices { + let detail = device + .detail + .as_deref() + .map(|d| format!(" ({d})")) + .unwrap_or_default(); + let path = device + .device_path + .as_deref() + .map(|p| format!(" → {p}")) + .unwrap_or_default(); + println!( + " {} {}{}{} [{}]", + style("›").cyan(), + style(&device.name).green(), + style(&detail).dim(), + style(&path).dim(), + style(device.transport.to_string()).cyan() + ); + } + } + println!(); + + let options = vec![ + "🚀 Native — direct GPIO on this Linux board (Raspberry Pi, Orange Pi, etc.)", + "🔌 Tethered — control an Arduino/ESP32/Nucleo plugged into USB", + "🔬 Debug Probe — flash/read MCUs via SWD/JTAG (probe-rs)", + "☁️ Software Only — no hardware access (default)", + ]; + + let recommended = hardware::recommended_wizard_default(&devices); + + let choice = Select::new() + .with_prompt(" How should ZeroClaw interact with the physical world?") + .items(&options) + .default(recommended) + .interact()?; + + let mut hw_config = hardware::config_from_wizard_choice(choice, &devices); + + // ── Serial: pick a port if multiple found ── + if hw_config.transport_mode() == hardware::HardwareTransport::Serial { + let serial_devices: Vec<&hardware::DiscoveredDevice> = devices + .iter() + .filter(|d| d.transport == hardware::HardwareTransport::Serial) + .collect(); + + if serial_devices.len() > 1 { + let port_labels: Vec = serial_devices + .iter() + .map(|d| { + format!( + "{} ({})", + d.device_path.as_deref().unwrap_or("unknown"), + d.name + ) + }) + .collect(); + + let port_idx = Select::new() + .with_prompt(" Multiple serial devices found — select one") + .items(&port_labels) + .default(0) + .interact()?; + + hw_config.serial_port = serial_devices[port_idx].device_path.clone(); + } else if serial_devices.is_empty() { + // User chose serial but no device discovered — ask for manual path + let manual_port: String = Input::new() + .with_prompt(" Serial port path (e.g. /dev/ttyUSB0)") + .default("/dev/ttyUSB0".into()) + .interact_text()?; + hw_config.serial_port = Some(manual_port); + } + + // Baud rate + let baud_options = vec![ + "115200 (default, recommended)", + "9600 (legacy Arduino)", + "57600", + "230400", + "Custom", + ]; + let baud_idx = Select::new() + .with_prompt(" Serial baud rate") + .items(&baud_options) + .default(0) + .interact()?; + + hw_config.baud_rate = match baud_idx { + 1 => 9600, + 2 => 57600, + 3 => 230_400, + 4 => { + let custom: String = Input::new() + .with_prompt(" Custom baud rate") + .default("115200".into()) + .interact_text()?; + custom.parse::().unwrap_or(115_200) + } + _ => 115_200, + }; + } + + // ── Probe: ask for target chip ── + if hw_config.transport_mode() == hardware::HardwareTransport::Probe + && hw_config.probe_target.is_none() + { + let target: String = Input::new() + .with_prompt(" Target MCU chip (e.g. STM32F411CEUx, nRF52840_xxAA)") + .default("STM32F411CEUx".into()) + .interact_text()?; + hw_config.probe_target = Some(target); + } + + // ── Datasheet RAG ── + if hw_config.enabled { + let datasheets = Confirm::new() + .with_prompt(" Enable datasheet RAG? (index PDF schematics for AI pin lookups)") + .default(true) + .interact()?; + hw_config.workspace_datasheets = datasheets; + } + + // ── Summary ── + if hw_config.enabled { + let transport_label = match hw_config.transport_mode() { + hardware::HardwareTransport::Native => "Native GPIO".to_string(), + hardware::HardwareTransport::Serial => format!( + "Serial → {} @ {} baud", + hw_config.serial_port.as_deref().unwrap_or("?"), + hw_config.baud_rate + ), + hardware::HardwareTransport::Probe => format!( + "Probe (SWD/JTAG) → {}", + hw_config.probe_target.as_deref().unwrap_or("?") + ), + hardware::HardwareTransport::None => "Software Only".to_string(), + }; + + println!( + " {} Hardware: {} | datasheets: {}", + style("✓").green().bold(), + style(&transport_label).green(), + if hw_config.workspace_datasheets { + style("on").green().to_string() + } else { + style("off").dim().to_string() + } + ); + } else { + println!( + " {} Hardware: {}", + style("✓").green().bold(), + style("disabled (software only)").dim() + ); + } + + Ok(hw_config) +} + // ── Step 6: Project Context ───────────────────────────────────── fn setup_project_context() -> Result { @@ -985,11 +2375,10 @@ fn setup_memory() -> Result { print_bullet("You can always change this later in config.toml."); println!(); - let options = vec![ - "SQLite with Vector Search (recommended) — fast, hybrid search, embeddings", - "Markdown Files — simple, human-readable, no dependencies", - "None — disable persistent memory", - ]; + let options: Vec<&str> = selectable_memory_backends() + .iter() + .map(|backend| backend.label) + .collect(); let choice = Select::new() .with_prompt(" Select memory backend") @@ -997,21 +2386,14 @@ fn setup_memory() -> Result { .default(0) .interact()?; - let backend = match choice { - 1 => "markdown", - 2 => "none", - _ => "sqlite", // 0 and any unexpected value defaults to sqlite - }; + let backend = backend_key_from_choice(choice); + let profile = memory_backend_profile(backend); - let auto_save = if backend == "none" { - false - } else { - let save = Confirm::new() + let auto_save = profile.auto_save_default + && Confirm::new() .with_prompt(" Auto-save conversations to memory?") .default(true) .interact()?; - save - }; println!( " {} Memory: {} (auto-save: {})", @@ -1020,21 +2402,9 @@ fn setup_memory() -> Result { if auto_save { "on" } else { "off" } ); - Ok(MemoryConfig { - backend: backend.to_string(), - auto_save, - hygiene_enabled: backend == "sqlite", // Only enable hygiene for SQLite - archive_after_days: if backend == "sqlite" { 7 } else { 0 }, - purge_after_days: if backend == "sqlite" { 30 } else { 0 }, - conversation_retention_days: 30, - embedding_provider: "none".to_string(), - embedding_model: "text-embedding-3-small".to_string(), - embedding_dimensions: 1536, - vector_weight: 0.7, - keyword_weight: 0.3, - embedding_cache_size: if backend == "sqlite" { 10000 } else { 0 }, - chunk_max_tokens: 512, - }) + let mut config = memory_config_defaults_for_backend(backend); + config.auto_save = auto_save; + Ok(config) } // ── Step 3: Channels ──────────────────────────────────────────── @@ -1050,10 +2420,17 @@ fn setup_channels() -> Result { telegram: None, discord: None, slack: None, + mattermost: None, webhook: None, imessage: None, matrix: None, + signal: None, whatsapp: None, + email: None, + irc: None, + lark: None, + dingtalk: None, + qq: None, }; loop { @@ -1106,6 +2483,14 @@ fn setup_channels() -> Result { "— Business Cloud API" } ), + format!( + "IRC {}", + if config.irc.is_some() { + "✅ configured" + } else { + "— IRC over TLS" + } + ), format!( "Webhook {}", if config.webhook.is_some() { @@ -1114,13 +2499,29 @@ fn setup_channels() -> Result { "— HTTP endpoint" } ), + format!( + "DingTalk {}", + if config.dingtalk.is_some() { + "✅ connected" + } else { + "— DingTalk Stream Mode" + } + ), + format!( + "QQ Official {}", + if config.qq.is_some() { + "✅ connected" + } else { + "— Tencent QQ Bot" + } + ), "Done — finish setup".to_string(), ]; let choice = Select::new() .with_prompt(" Connect a channel (or Done to continue)") .items(&options) - .default(7) + .default(10) .interact()?; match choice { @@ -1146,18 +2547,27 @@ fn setup_channels() -> Result { continue; } - // Test connection + // Test connection (run entirely in separate thread — reqwest::blocking Response + // must be used and dropped there to avoid "Cannot drop a runtime" panic) print!(" {} Testing connection... ", style("⏳").dim()); - let client = reqwest::blocking::Client::new(); - let url = format!("https://api.telegram.org/bot{token}/getMe"); - match client.get(&url).send() { - Ok(resp) if resp.status().is_success() => { - let data: serde_json::Value = resp.json().unwrap_or_default(); - let bot_name = data - .get("result") - .and_then(|r| r.get("username")) - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown"); + let token_clone = token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let url = format!("https://api.telegram.org/bot{token_clone}/getMe"); + let resp = client.get(&url).send()?; + let ok = resp.status().is_success(); + let data: serde_json::Value = resp.json().unwrap_or_default(); + let bot_name = data + .get("result") + .and_then(|r| r.get("username")) + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown") + .to_string(); + Ok::<_, reqwest::Error>((ok, bot_name)) + }) + .join(); + match thread_result { + Ok(Ok((true, bot_name))) => { println!( "\r {} Connected as @{bot_name} ", style("✅").green().bold() @@ -1230,20 +2640,27 @@ fn setup_channels() -> Result { continue; } - // Test connection + // Test connection (run entirely in separate thread — Response must be used/dropped there) print!(" {} Testing connection... ", style("⏳").dim()); - let client = reqwest::blocking::Client::new(); - match client - .get("https://discord.com/api/v10/users/@me") - .header("Authorization", format!("Bot {token}")) - .send() - { - Ok(resp) if resp.status().is_success() => { - let data: serde_json::Value = resp.json().unwrap_or_default(); - let bot_name = data - .get("username") - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown"); + let token_clone = token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let resp = client + .get("https://discord.com/api/v10/users/@me") + .header("Authorization", format!("Bot {token_clone}")) + .send()?; + let ok = resp.status().is_success(); + let data: serde_json::Value = resp.json().unwrap_or_default(); + let bot_name = data + .get("username") + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown") + .to_string(); + Ok::<_, reqwest::Error>((ok, bot_name)) + }) + .join(); + match thread_result { + Ok(Ok((true, bot_name))) => { println!( "\r {} Connected as {bot_name} ", style("✅").green().bold() @@ -1297,6 +2714,8 @@ fn setup_channels() -> Result { bot_token: token, guild_id: if guild.is_empty() { None } else { Some(guild) }, allowed_users, + listen_to_bots: false, + mention_only: false, }); } 2 => { @@ -1321,37 +2740,44 @@ fn setup_channels() -> Result { continue; } - // Test connection + // Test connection (run entirely in separate thread — Response must be used/dropped there) print!(" {} Testing connection... ", style("⏳").dim()); - let client = reqwest::blocking::Client::new(); - match client - .get("https://slack.com/api/auth.test") - .bearer_auth(&token) - .send() - { - Ok(resp) if resp.status().is_success() => { - let data: serde_json::Value = resp.json().unwrap_or_default(); - let ok = data - .get("ok") - .and_then(serde_json::Value::as_bool) - .unwrap_or(false); - let team = data - .get("team") - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown"); - if ok { - println!( - "\r {} Connected to workspace: {team} ", - style("✅").green().bold() - ); - } else { - let err = data - .get("error") - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown error"); - println!("\r {} Slack error: {err}", style("❌").red().bold()); - continue; - } + let token_clone = token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let resp = client + .get("https://slack.com/api/auth.test") + .bearer_auth(&token_clone) + .send()?; + let ok = resp.status().is_success(); + let data: serde_json::Value = resp.json().unwrap_or_default(); + let api_ok = data + .get("ok") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false); + let team = data + .get("team") + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown") + .to_string(); + let err = data + .get("error") + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown error") + .to_string(); + Ok::<_, reqwest::Error>((ok, api_ok, team, err)) + }) + .join(); + match thread_result { + Ok(Ok((true, true, team, _))) => { + println!( + "\r {} Connected to workspace: {team} ", + style("✅").green().bold() + ); + } + Ok(Ok((true, false, _, err))) => { + println!("\r {} Slack error: {err}", style("❌").red().bold()); + continue; } _ => { println!( @@ -1490,26 +2916,26 @@ fn setup_channels() -> Result { continue; } - // Test connection + // Test connection (run entirely in separate thread — Response must be used/dropped there) let hs = homeserver.trim_end_matches('/'); print!(" {} Testing connection... ", style("⏳").dim()); - let client = reqwest::blocking::Client::new(); - match client - .get(format!("{hs}/_matrix/client/v3/account/whoami")) - .header("Authorization", format!("Bearer {access_token}")) - .send() - { - Ok(resp) if resp.status().is_success() => { - let data: serde_json::Value = resp.json().unwrap_or_default(); - let user_id = data - .get("user_id") - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown"); - println!( - "\r {} Connected as {user_id} ", - style("✅").green().bold() - ); - } + let hs_owned = hs.to_string(); + let access_token_clone = access_token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let resp = client + .get(format!("{hs_owned}/_matrix/client/v3/account/whoami")) + .header("Authorization", format!("Bearer {access_token_clone}")) + .send()?; + let ok = resp.status().is_success(); + Ok::<_, reqwest::Error>(ok) + }) + .join(); + match thread_result { + Ok(Ok(true)) => println!( + "\r {} Connection verified ", + style("✅").green().bold() + ), _ => { println!( "\r {} Connection failed — check homeserver URL and token", @@ -1578,19 +3004,28 @@ fn setup_channels() -> Result { .default("zeroclaw-whatsapp-verify".into()) .interact_text()?; - // Test connection + // Test connection (run entirely in separate thread — Response must be used/dropped there) print!(" {} Testing connection... ", style("⏳").dim()); - let client = reqwest::blocking::Client::new(); - let url = format!( - "https://graph.facebook.com/v18.0/{}", - phone_number_id.trim() - ); - match client - .get(&url) - .header("Authorization", format!("Bearer {}", access_token.trim())) - .send() - { - Ok(resp) if resp.status().is_success() => { + let phone_number_id_clone = phone_number_id.clone(); + let access_token_clone = access_token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let url = format!( + "https://graph.facebook.com/v18.0/{}", + phone_number_id_clone.trim() + ); + let resp = client + .get(&url) + .header( + "Authorization", + format!("Bearer {}", access_token_clone.trim()), + ) + .send()?; + Ok::<_, reqwest::Error>(resp.status().is_success()) + }) + .join(); + match thread_result { + Ok(Ok(true)) => { println!( "\r {} Connected to WhatsApp API ", style("✅").green().bold() @@ -1622,10 +3057,150 @@ fn setup_channels() -> Result { access_token: access_token.trim().to_string(), phone_number_id: phone_number_id.trim().to_string(), verify_token: verify_token.trim().to_string(), + app_secret: None, // Can be set via ZEROCLAW_WHATSAPP_APP_SECRET env var allowed_numbers, }); } 6 => { + // ── IRC ── + println!(); + println!( + " {} {}", + style("IRC Setup").white().bold(), + style("— IRC over TLS").dim() + ); + print_bullet("IRC connects over TLS to any IRC server"); + print_bullet("Supports SASL PLAIN and NickServ authentication"); + println!(); + + let server: String = Input::new() + .with_prompt(" IRC server (hostname)") + .interact_text()?; + + if server.trim().is_empty() { + println!(" {} Skipped", style("→").dim()); + continue; + } + + let port_str: String = Input::new() + .with_prompt(" Port") + .default("6697".into()) + .interact_text()?; + + let port: u16 = match port_str.trim().parse() { + Ok(p) => p, + Err(_) => { + println!(" {} Invalid port, using 6697", style("→").dim()); + 6697 + } + }; + + let nickname: String = + Input::new().with_prompt(" Bot nickname").interact_text()?; + + if nickname.trim().is_empty() { + println!(" {} Skipped — nickname required", style("→").dim()); + continue; + } + + let channels_str: String = Input::new() + .with_prompt(" Channels to join (comma-separated: #channel1,#channel2)") + .allow_empty(true) + .interact_text()?; + + let channels = if channels_str.trim().is_empty() { + vec![] + } else { + channels_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() + }; + + print_bullet( + "Allowlist nicknames that can interact with the bot (case-insensitive).", + ); + print_bullet("Use '*' to allow anyone (not recommended for production)."); + + let users_str: String = Input::new() + .with_prompt(" Allowed nicknames (comma-separated, or * for all)") + .allow_empty(true) + .interact_text()?; + + let allowed_users = if users_str.trim() == "*" { + vec!["*".into()] + } else { + users_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() + }; + + if allowed_users.is_empty() { + print_bullet( + "⚠️ Empty allowlist — only you can interact. Add nicknames above.", + ); + } + + println!(); + print_bullet("Optional authentication (press Enter to skip each):"); + + let server_password: String = Input::new() + .with_prompt(" Server password (for bouncers like ZNC, leave empty if none)") + .allow_empty(true) + .interact_text()?; + + let nickserv_password: String = Input::new() + .with_prompt(" NickServ password (leave empty if none)") + .allow_empty(true) + .interact_text()?; + + let sasl_password: String = Input::new() + .with_prompt(" SASL PLAIN password (leave empty if none)") + .allow_empty(true) + .interact_text()?; + + let verify_tls: bool = Confirm::new() + .with_prompt(" Verify TLS certificate?") + .default(true) + .interact()?; + + println!( + " {} IRC configured as {}@{}:{}", + style("✅").green().bold(), + style(&nickname).cyan(), + style(&server).cyan(), + style(port).cyan() + ); + + config.irc = Some(IrcConfig { + server: server.trim().to_string(), + port, + nickname: nickname.trim().to_string(), + username: None, + channels, + allowed_users, + server_password: if server_password.trim().is_empty() { + None + } else { + Some(server_password.trim().to_string()) + }, + nickserv_password: if nickserv_password.trim().is_empty() { + None + } else { + Some(nickserv_password.trim().to_string()) + }, + sasl_password: if sasl_password.trim().is_empty() { + None + } else { + Some(sasl_password.trim().to_string()) + }, + verify_tls: Some(verify_tls), + }); + } + 7 => { // ── Webhook ── println!(); println!( @@ -1658,6 +3233,152 @@ fn setup_channels() -> Result { style(&port).cyan() ); } + 8 => { + // ── DingTalk ── + println!(); + println!( + " {} {}", + style("DingTalk Setup").white().bold(), + style("— DingTalk Stream Mode").dim() + ); + print_bullet("1. Go to DingTalk developer console (open.dingtalk.com)"); + print_bullet("2. Create an app and enable the Stream Mode bot"); + print_bullet("3. Copy the Client ID (AppKey) and Client Secret (AppSecret)"); + println!(); + + let client_id: String = Input::new() + .with_prompt(" Client ID (AppKey)") + .interact_text()?; + + if client_id.trim().is_empty() { + println!(" {} Skipped", style("→").dim()); + continue; + } + + let client_secret: String = Input::new() + .with_prompt(" Client Secret (AppSecret)") + .interact_text()?; + + // Test connection + print!(" {} Testing connection... ", style("⏳").dim()); + let client = reqwest::blocking::Client::new(); + let body = serde_json::json!({ + "clientId": client_id, + "clientSecret": client_secret, + }); + match client + .post("https://api.dingtalk.com/v1.0/gateway/connections/open") + .json(&body) + .send() + { + Ok(resp) if resp.status().is_success() => { + println!( + "\r {} DingTalk credentials verified ", + style("✅").green().bold() + ); + } + _ => { + println!( + "\r {} Connection failed — check your credentials", + style("❌").red().bold() + ); + continue; + } + } + + let users_str: String = Input::new() + .with_prompt(" Allowed staff IDs (comma-separated, '*' for all)") + .allow_empty(true) + .interact_text()?; + + let allowed_users: Vec = users_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + config.dingtalk = Some(DingTalkConfig { + client_id, + client_secret, + allowed_users, + }); + } + 9 => { + // ── QQ Official ── + println!(); + println!( + " {} {}", + style("QQ Official Setup").white().bold(), + style("— Tencent QQ Bot SDK").dim() + ); + print_bullet("1. Go to QQ Bot developer console (q.qq.com)"); + print_bullet("2. Create a bot application"); + print_bullet("3. Copy the App ID and App Secret"); + println!(); + + let app_id: String = Input::new().with_prompt(" App ID").interact_text()?; + + if app_id.trim().is_empty() { + println!(" {} Skipped", style("→").dim()); + continue; + } + + let app_secret: String = + Input::new().with_prompt(" App Secret").interact_text()?; + + // Test connection + print!(" {} Testing connection... ", style("⏳").dim()); + let client = reqwest::blocking::Client::new(); + let body = serde_json::json!({ + "appId": app_id, + "clientSecret": app_secret, + }); + match client + .post("https://bots.qq.com/app/getAppAccessToken") + .json(&body) + .send() + { + Ok(resp) if resp.status().is_success() => { + let data: serde_json::Value = resp.json().unwrap_or_default(); + if data.get("access_token").is_some() { + println!( + "\r {} QQ Bot credentials verified ", + style("✅").green().bold() + ); + } else { + println!( + "\r {} Auth error — check your credentials", + style("❌").red().bold() + ); + continue; + } + } + _ => { + println!( + "\r {} Connection failed — check your credentials", + style("❌").red().bold() + ); + continue; + } + } + + let users_str: String = Input::new() + .with_prompt(" Allowed user IDs (comma-separated, '*' for all)") + .allow_empty(true) + .interact_text()?; + + let allowed_users: Vec = users_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + config.qq = Some(QQConfig { + app_id, + app_secret, + allowed_users, + }); + } _ => break, // Done } println!(); @@ -1683,9 +3404,21 @@ fn setup_channels() -> Result { if config.whatsapp.is_some() { active.push("WhatsApp"); } + if config.email.is_some() { + active.push("Email"); + } + if config.irc.is_some() { + active.push("IRC"); + } if config.webhook.is_some() { active.push("Webhook"); } + if config.dingtalk.is_some() { + active.push("DingTalk"); + } + if config.qq.is_some() { + active.push("QQ"); + } println!( " {} Channels: {}", @@ -2135,7 +3868,10 @@ fn print_summary(config: &Config) { || config.channels_config.discord.is_some() || config.channels_config.slack.is_some() || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some(); + || config.channels_config.matrix.is_some() + || config.channels_config.email.is_some() + || config.channels_config.dingtalk.is_some() + || config.channels_config.qq.is_some(); println!(); println!( @@ -2197,6 +3933,9 @@ fn print_summary(config: &Config) { if config.channels_config.matrix.is_some() { channels.push("Matrix"); } + if config.channels_config.email.is_some() { + channels.push("Email"); + } if config.channels_config.webhook.is_some() { channels.push("Webhook"); } @@ -2241,15 +3980,7 @@ fn print_summary(config: &Config) { ); // Secrets - println!( - " {} Secrets: {}", - style("🔒").cyan(), - if config.secrets.encrypt { - style("encrypted").green().to_string() - } else { - style("plaintext").yellow().to_string() - } - ); + println!(" {} Secrets: configured", style("🔒").cyan()); // Gateway println!( @@ -2262,6 +3993,40 @@ fn print_summary(config: &Config) { } ); + // Hardware + println!( + " {} Hardware: {}", + style("🔌").cyan(), + if config.hardware.enabled { + let mode = config.hardware.transport_mode(); + match mode { + hardware::HardwareTransport::Native => { + style("Native GPIO (direct)").green().to_string() + } + hardware::HardwareTransport::Serial => format!( + "{}", + style(format!( + "Serial → {} @ {} baud", + config.hardware.serial_port.as_deref().unwrap_or("?"), + config.hardware.baud_rate + )) + .green() + ), + hardware::HardwareTransport::Probe => format!( + "{}", + style(format!( + "Probe → {}", + config.hardware.probe_target.as_deref().unwrap_or("?") + )) + .green() + ), + hardware::HardwareTransport::None => "disabled (software only)".to_string(), + } + } else { + "disabled (software only)".to_string() + } + ); + println!(); println!(" {}", style("Next steps:").white().bold()); println!(); @@ -2331,6 +4096,7 @@ fn print_summary(config: &Config) { #[cfg(test)] mod tests { use super::*; + use serde_json::json; use tempfile::TempDir; // ── ProjectContext defaults ────────────────────────────────── @@ -2746,6 +4512,248 @@ mod tests { assert!(heartbeat.contains("Claw")); } + // ── model helper coverage ─────────────────────────────────── + + #[test] + fn default_model_for_provider_uses_latest_defaults() { + assert_eq!(default_model_for_provider("openai"), "gpt-5.2"); + assert_eq!( + default_model_for_provider("anthropic"), + "claude-sonnet-4-5-20250929" + ); + assert_eq!(default_model_for_provider("qwen"), "qwen-plus"); + assert_eq!(default_model_for_provider("qwen-intl"), "qwen-plus"); + assert_eq!(default_model_for_provider("glm-cn"), "glm-5"); + assert_eq!(default_model_for_provider("minimax-cn"), "MiniMax-M2.5"); + assert_eq!(default_model_for_provider("zai-cn"), "glm-5"); + assert_eq!(default_model_for_provider("gemini"), "gemini-2.5-pro"); + assert_eq!(default_model_for_provider("google"), "gemini-2.5-pro"); + assert_eq!( + default_model_for_provider("google-gemini"), + "gemini-2.5-pro" + ); + } + + #[test] + fn canonical_provider_name_normalizes_regional_aliases() { + assert_eq!(canonical_provider_name("qwen-intl"), "qwen"); + assert_eq!(canonical_provider_name("dashscope-us"), "qwen"); + assert_eq!(canonical_provider_name("moonshot-intl"), "moonshot"); + assert_eq!(canonical_provider_name("kimi-cn"), "moonshot"); + assert_eq!(canonical_provider_name("glm-cn"), "glm"); + assert_eq!(canonical_provider_name("bigmodel"), "glm"); + assert_eq!(canonical_provider_name("minimax-cn"), "minimax"); + assert_eq!(canonical_provider_name("zai-cn"), "zai"); + assert_eq!(canonical_provider_name("z.ai-global"), "zai"); + } + + #[test] + fn curated_models_for_openai_include_latest_choices() { + let ids: Vec = curated_models_for_provider("openai") + .into_iter() + .map(|(id, _)| id) + .collect(); + + assert!(ids.contains(&"gpt-5.2".to_string())); + assert!(ids.contains(&"gpt-5-mini".to_string())); + } + + #[test] + fn curated_models_for_openrouter_use_valid_anthropic_id() { + let ids: Vec = curated_models_for_provider("openrouter") + .into_iter() + .map(|(id, _)| id) + .collect(); + + assert!(ids.contains(&"anthropic/claude-sonnet-4.5".to_string())); + } + + #[test] + fn supports_live_model_fetch_for_supported_and_unsupported_providers() { + assert!(supports_live_model_fetch("openai")); + assert!(supports_live_model_fetch("anthropic")); + assert!(supports_live_model_fetch("gemini")); + assert!(supports_live_model_fetch("google")); + assert!(supports_live_model_fetch("grok")); + assert!(supports_live_model_fetch("together")); + assert!(supports_live_model_fetch("ollama")); + assert!(!supports_live_model_fetch("venice")); + } + + #[test] + fn curated_models_provider_aliases_share_same_catalog() { + assert_eq!( + curated_models_for_provider("xai"), + curated_models_for_provider("grok") + ); + assert_eq!( + curated_models_for_provider("together-ai"), + curated_models_for_provider("together") + ); + assert_eq!( + curated_models_for_provider("gemini"), + curated_models_for_provider("google") + ); + assert_eq!( + curated_models_for_provider("gemini"), + curated_models_for_provider("google-gemini") + ); + assert_eq!( + curated_models_for_provider("qwen"), + curated_models_for_provider("qwen-intl") + ); + assert_eq!( + curated_models_for_provider("qwen"), + curated_models_for_provider("dashscope-us") + ); + assert_eq!( + curated_models_for_provider("minimax"), + curated_models_for_provider("minimax-cn") + ); + assert_eq!( + curated_models_for_provider("zai"), + curated_models_for_provider("zai-cn") + ); + } + + #[test] + fn parse_openai_model_ids_supports_data_array_payload() { + let payload = json!({ + "data": [ + {"id": " gpt-5.1 "}, + {"id": "gpt-5-mini"}, + {"id": "gpt-5.1"}, + {"id": ""} + ] + }); + + let ids = parse_openai_compatible_model_ids(&payload); + assert_eq!(ids, vec!["gpt-5-mini".to_string(), "gpt-5.1".to_string()]); + } + + #[test] + fn parse_openai_model_ids_supports_root_array_payload() { + let payload = json!([ + {"id": "alpha"}, + {"id": "beta"}, + {"id": "alpha"} + ]); + + let ids = parse_openai_compatible_model_ids(&payload); + assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]); + } + + #[test] + fn parse_gemini_model_ids_filters_for_generate_content() { + let payload = json!({ + "models": [ + { + "name": "models/gemini-2.5-pro", + "supportedGenerationMethods": ["generateContent", "countTokens"] + }, + { + "name": "models/text-embedding-004", + "supportedGenerationMethods": ["embedContent"] + }, + { + "name": "models/gemini-2.5-flash", + "supportedGenerationMethods": ["generateContent"] + } + ] + }); + + let ids = parse_gemini_model_ids(&payload); + assert_eq!( + ids, + vec!["gemini-2.5-flash".to_string(), "gemini-2.5-pro".to_string()] + ); + } + + #[test] + fn parse_ollama_model_ids_extracts_and_deduplicates_names() { + let payload = json!({ + "models": [ + {"name": "llama3.2:latest"}, + {"name": "mistral:latest"}, + {"name": "llama3.2:latest"} + ] + }); + + let ids = parse_ollama_model_ids(&payload); + assert_eq!( + ids, + vec!["llama3.2:latest".to_string(), "mistral:latest".to_string()] + ); + } + + #[test] + fn model_cache_round_trip_returns_fresh_entry() { + let tmp = TempDir::new().unwrap(); + let models = vec!["gpt-5.1".to_string(), "gpt-5-mini".to_string()]; + + cache_live_models_for_provider(tmp.path(), "openai", &models).unwrap(); + + let cached = + load_cached_models_for_provider(tmp.path(), "openai", MODEL_CACHE_TTL_SECS).unwrap(); + let cached = cached.expect("expected fresh cached models"); + + assert_eq!(cached.models.len(), 2); + assert!(cached.models.contains(&"gpt-5.1".to_string())); + assert!(cached.models.contains(&"gpt-5-mini".to_string())); + } + + #[test] + fn model_cache_ttl_filters_stale_entries() { + let tmp = TempDir::new().unwrap(); + let stale = ModelCacheState { + entries: vec![ModelCacheEntry { + provider: "openai".to_string(), + fetched_at_unix: now_unix_secs().saturating_sub(MODEL_CACHE_TTL_SECS + 120), + models: vec!["gpt-5.1".to_string()], + }], + }; + + save_model_cache_state(tmp.path(), &stale).unwrap(); + + let fresh = + load_cached_models_for_provider(tmp.path(), "openai", MODEL_CACHE_TTL_SECS).unwrap(); + assert!(fresh.is_none()); + + let stale_any = load_any_cached_models_for_provider(tmp.path(), "openai").unwrap(); + assert!(stale_any.is_some()); + } + + #[test] + fn run_models_refresh_uses_fresh_cache_without_network() { + let tmp = TempDir::new().unwrap(); + + cache_live_models_for_provider(tmp.path(), "openai", &["gpt-5.1".to_string()]).unwrap(); + + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + default_provider: Some("openai".to_string()), + ..Config::default() + }; + + run_models_refresh(&config, None, false).unwrap(); + } + + #[test] + fn run_models_refresh_rejects_unsupported_provider() { + let tmp = TempDir::new().unwrap(); + + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + default_provider: Some("venice".to_string()), + ..Config::default() + }; + + let err = run_models_refresh(&config, None, true).unwrap_err(); + assert!(err + .to_string() + .contains("does not support live model discovery")); + } + // ── provider_env_var ──────────────────────────────────────── #[test] @@ -2753,15 +4761,80 @@ mod tests { assert_eq!(provider_env_var("openrouter"), "OPENROUTER_API_KEY"); assert_eq!(provider_env_var("anthropic"), "ANTHROPIC_API_KEY"); assert_eq!(provider_env_var("openai"), "OPENAI_API_KEY"); - assert_eq!(provider_env_var("ollama"), "API_KEY"); // fallback + assert_eq!(provider_env_var("ollama"), "OLLAMA_API_KEY"); assert_eq!(provider_env_var("xai"), "XAI_API_KEY"); assert_eq!(provider_env_var("grok"), "XAI_API_KEY"); // alias - assert_eq!(provider_env_var("together"), "TOGETHER_API_KEY"); - assert_eq!(provider_env_var("together-ai"), "TOGETHER_API_KEY"); // alias + assert_eq!(provider_env_var("together"), "TOGETHER_API_KEY"); // alias + assert_eq!(provider_env_var("together-ai"), "TOGETHER_API_KEY"); + assert_eq!(provider_env_var("google"), "GEMINI_API_KEY"); // alias + assert_eq!(provider_env_var("google-gemini"), "GEMINI_API_KEY"); // alias + assert_eq!(provider_env_var("gemini"), "GEMINI_API_KEY"); + assert_eq!(provider_env_var("qwen"), "DASHSCOPE_API_KEY"); + assert_eq!(provider_env_var("qwen-intl"), "DASHSCOPE_API_KEY"); + assert_eq!(provider_env_var("dashscope-us"), "DASHSCOPE_API_KEY"); + assert_eq!(provider_env_var("glm-cn"), "GLM_API_KEY"); + assert_eq!(provider_env_var("minimax-cn"), "MINIMAX_API_KEY"); + assert_eq!(provider_env_var("moonshot-intl"), "MOONSHOT_API_KEY"); + assert_eq!(provider_env_var("zai-cn"), "ZAI_API_KEY"); + assert_eq!(provider_env_var("nvidia"), "NVIDIA_API_KEY"); + assert_eq!(provider_env_var("nvidia-nim"), "NVIDIA_API_KEY"); // alias + assert_eq!(provider_env_var("build.nvidia.com"), "NVIDIA_API_KEY"); // alias } #[test] fn provider_env_var_unknown_falls_back() { assert_eq!(provider_env_var("some-new-provider"), "API_KEY"); } + + #[test] + fn backend_key_from_choice_maps_supported_backends() { + assert_eq!(backend_key_from_choice(0), "sqlite"); + assert_eq!(backend_key_from_choice(1), "lucid"); + assert_eq!(backend_key_from_choice(2), "markdown"); + assert_eq!(backend_key_from_choice(3), "none"); + assert_eq!(backend_key_from_choice(999), "sqlite"); + } + + #[test] + fn memory_backend_profile_marks_lucid_as_optional_sqlite_backed() { + let lucid = memory_backend_profile("lucid"); + assert!(lucid.auto_save_default); + assert!(lucid.uses_sqlite_hygiene); + assert!(lucid.sqlite_based); + assert!(lucid.optional_dependency); + + let markdown = memory_backend_profile("markdown"); + assert!(markdown.auto_save_default); + assert!(!markdown.uses_sqlite_hygiene); + + let none = memory_backend_profile("none"); + assert!(!none.auto_save_default); + assert!(!none.uses_sqlite_hygiene); + + let custom = memory_backend_profile("custom-memory"); + assert!(custom.auto_save_default); + assert!(!custom.uses_sqlite_hygiene); + } + + #[test] + fn memory_config_defaults_for_lucid_enable_sqlite_hygiene() { + let config = memory_config_defaults_for_backend("lucid"); + assert_eq!(config.backend, "lucid"); + assert!(config.auto_save); + assert!(config.hygiene_enabled); + assert_eq!(config.archive_after_days, 7); + assert_eq!(config.purge_after_days, 30); + assert_eq!(config.embedding_cache_size, 10000); + } + + #[test] + fn memory_config_defaults_for_none_disable_sqlite_hygiene() { + let config = memory_config_defaults_for_backend("none"); + assert_eq!(config.backend, "none"); + assert!(!config.auto_save); + assert!(!config.hygiene_enabled); + assert_eq!(config.archive_after_days, 0); + assert_eq!(config.purge_after_days, 0); + assert_eq!(config.embedding_cache_size, 0); + } } diff --git a/src/peripherals/arduino_flash.rs b/src/peripherals/arduino_flash.rs new file mode 100644 index 0000000..4144273 --- /dev/null +++ b/src/peripherals/arduino_flash.rs @@ -0,0 +1,145 @@ +//! Flash ZeroClaw Arduino firmware via arduino-cli. +//! +//! Ensures arduino-cli is available (installs via brew on macOS if missing), +//! installs the AVR core, compiles and uploads the base firmware. + +use anyhow::{Context, Result}; +use std::process::Command; + +/// ZeroClaw Arduino Uno base firmware (capabilities, gpio_read, gpio_write). +const FIRMWARE_INO: &str = include_str!("../../firmware/zeroclaw-arduino/zeroclaw-arduino.ino"); + +const FQBN: &str = "arduino:avr:uno"; +const SKETCH_NAME: &str = "zeroclaw-arduino"; + +/// Check if arduino-cli is available. +pub fn arduino_cli_available() -> bool { + Command::new("arduino-cli") + .arg("version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +/// Try to install arduino-cli. Returns Ok(()) if installed or already present. +pub fn ensure_arduino_cli() -> Result<()> { + if arduino_cli_available() { + return Ok(()); + } + + #[cfg(target_os = "macos")] + { + println!("arduino-cli not found. Installing via Homebrew..."); + let status = Command::new("brew") + .args(["install", "arduino-cli"]) + .status() + .context("Failed to run brew install")?; + if !status.success() { + anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/"); + } + println!("arduino-cli installed."); + if !arduino_cli_available() { + anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH."); + } + } + + #[cfg(target_os = "linux")] + { + println!("arduino-cli not found. Run the install script:"); + println!(" curl -fsSL https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh"); + println!(); + println!("Or install via package manager (e.g. apt install arduino-cli on Debian/Ubuntu)."); + anyhow::bail!("arduino-cli not installed. Install it and try again."); + } + + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + { + println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/"); + anyhow::bail!("arduino-cli not installed."); + } + + #[allow(unreachable_code)] + Ok(()) +} + +/// Ensure arduino:avr core is installed. +fn ensure_avr_core() -> Result<()> { + let out = Command::new("arduino-cli") + .args(["core", "list"]) + .output() + .context("arduino-cli core list failed")?; + let stdout = String::from_utf8_lossy(&out.stdout); + if stdout.contains("arduino:avr") { + return Ok(()); + } + + println!("Installing Arduino AVR core..."); + let status = Command::new("arduino-cli") + .args(["core", "install", "arduino:avr"]) + .status() + .context("arduino-cli core install failed")?; + if !status.success() { + anyhow::bail!("Failed to install arduino:avr core"); + } + println!("AVR core installed."); + Ok(()) +} + +/// Flash ZeroClaw firmware to Arduino at the given port. +pub fn flash_arduino_firmware(port: &str) -> Result<()> { + ensure_arduino_cli()?; + ensure_avr_core()?; + + let temp_dir = std::env::temp_dir().join(format!("zeroclaw_flash_{}", uuid::Uuid::new_v4())); + let sketch_dir = temp_dir.join(SKETCH_NAME); + let ino_path = sketch_dir.join(format!("{}.ino", SKETCH_NAME)); + + std::fs::create_dir_all(&sketch_dir).context("Failed to create sketch dir")?; + std::fs::write(&ino_path, FIRMWARE_INO).context("Failed to write firmware")?; + + let sketch_path = sketch_dir.to_string_lossy(); + + // Compile + println!("Compiling ZeroClaw Arduino firmware..."); + let compile = Command::new("arduino-cli") + .args(["compile", "--fqbn", FQBN, &*sketch_path]) + .output() + .context("arduino-cli compile failed")?; + + if !compile.status.success() { + let stderr = String::from_utf8_lossy(&compile.stderr); + let _ = std::fs::remove_dir_all(&temp_dir); + anyhow::bail!("Compile failed:\n{}", stderr); + } + + // Upload + println!("Uploading to {}...", port); + let upload = Command::new("arduino-cli") + .args(["upload", "-p", port, "--fqbn", FQBN, &*sketch_path]) + .output() + .context("arduino-cli upload failed")?; + + let _ = std::fs::remove_dir_all(&temp_dir); + + if !upload.status.success() { + let stderr = String::from_utf8_lossy(&upload.stderr); + anyhow::bail!("Upload failed:\n{}\n\nEnsure the board is connected and the port is correct (e.g. /dev/cu.usbmodem* on macOS).", stderr); + } + + println!("ZeroClaw firmware flashed successfully."); + println!("The Arduino now supports: capabilities, gpio_read, gpio_write."); + Ok(()) +} + +/// Resolve port from config or path. Returns the path to use for flashing. +pub fn resolve_port(config: &crate::config::Config, path_override: Option<&str>) -> Option { + if let Some(p) = path_override { + return Some(p.to_string()); + } + config + .peripherals + .boards + .iter() + .find(|b| b.board == "arduino-uno" && b.transport == "serial") + .and_then(|b| b.path.clone()) +} diff --git a/src/peripherals/arduino_upload.rs b/src/peripherals/arduino_upload.rs new file mode 100644 index 0000000..e11b19f --- /dev/null +++ b/src/peripherals/arduino_upload.rs @@ -0,0 +1,161 @@ +//! Arduino upload tool — agent generates code, uploads via arduino-cli. +//! +//! When user says "make a heart on the LED grid", the agent generates Arduino +//! sketch code and calls this tool. ZeroClaw compiles and uploads it — no +//! manual IDE or file editing. + +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::process::Command; + +/// Tool: upload Arduino sketch (agent-generated code) to the board. +pub struct ArduinoUploadTool { + /// Serial port path (e.g. /dev/cu.usbmodem33000283452) + pub port: String, +} + +impl ArduinoUploadTool { + pub fn new(port: String) -> Self { + Self { port } + } +} + +#[async_trait] +impl Tool for ArduinoUploadTool { + fn name(&self) -> &str { + "arduino_upload" + } + + fn description(&self) -> &str { + "Generate Arduino sketch code and upload it to the connected Arduino. Use when: user asks to 'make a heart', 'blink LED', or run any custom pattern on Arduino. You MUST write the full .ino sketch code (setup + loop). Arduino Uno: pin 13 = built-in LED. Saves to temp dir, runs arduino-cli compile and upload. Requires arduino-cli installed." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Full Arduino sketch code (complete .ino file content)" + } + }, + "required": ["code"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let code = args + .get("code") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'code' parameter"))?; + + if code.trim().is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Code cannot be empty".into()), + }); + } + + // Check arduino-cli exists + if Command::new("arduino-cli").arg("version").output().is_err() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/" + .into(), + ), + }); + } + + let sketch_name = "zeroclaw_sketch"; + let temp_dir = std::env::temp_dir().join(format!("zeroclaw_{}", uuid::Uuid::new_v4())); + let sketch_dir = temp_dir.join(sketch_name); + let ino_path = sketch_dir.join(format!("{}.ino", sketch_name)); + + if let Err(e) = std::fs::create_dir_all(&sketch_dir) { + return Ok(ToolResult { + success: false, + output: format!("Failed to create sketch dir: {}", e), + error: Some(e.to_string()), + }); + } + + if let Err(e) = std::fs::write(&ino_path, code) { + let _ = std::fs::remove_dir_all(&temp_dir); + return Ok(ToolResult { + success: false, + output: format!("Failed to write sketch: {}", e), + error: Some(e.to_string()), + }); + } + + let sketch_path = sketch_dir.to_string_lossy(); + let fqbn = "arduino:avr:uno"; + + // Compile + let compile = Command::new("arduino-cli") + .args(["compile", "--fqbn", fqbn, &sketch_path]) + .output(); + + let compile_output = match compile { + Ok(o) => o, + Err(e) => { + let _ = std::fs::remove_dir_all(&temp_dir); + return Ok(ToolResult { + success: false, + output: format!("arduino-cli compile failed: {}", e), + error: Some(e.to_string()), + }); + } + }; + + if !compile_output.status.success() { + let stderr = String::from_utf8_lossy(&compile_output.stderr); + let _ = std::fs::remove_dir_all(&temp_dir); + return Ok(ToolResult { + success: false, + output: format!("Compile failed:\n{}", stderr), + error: Some("Arduino compile error".into()), + }); + } + + // Upload + let upload = Command::new("arduino-cli") + .args(["upload", "-p", &self.port, "--fqbn", fqbn, &sketch_path]) + .output(); + + let upload_output = match upload { + Ok(o) => o, + Err(e) => { + let _ = std::fs::remove_dir_all(&temp_dir); + return Ok(ToolResult { + success: false, + output: format!("arduino-cli upload failed: {}", e), + error: Some(e.to_string()), + }); + } + }; + + let _ = std::fs::remove_dir_all(&temp_dir); + + if !upload_output.status.success() { + let stderr = String::from_utf8_lossy(&upload_output.stderr); + return Ok(ToolResult { + success: false, + output: format!("Upload failed:\n{}", stderr), + error: Some("Arduino upload error".into()), + }); + } + + Ok(ToolResult { + success: true, + output: + "Sketch compiled and uploaded successfully. The Arduino is now running your code." + .into(), + error: None, + }) + } +} diff --git a/src/peripherals/capabilities_tool.rs b/src/peripherals/capabilities_tool.rs new file mode 100644 index 0000000..c3fca4f --- /dev/null +++ b/src/peripherals/capabilities_tool.rs @@ -0,0 +1,99 @@ +//! Hardware capabilities tool — Phase C: query device for reported GPIO pins. + +use super::serial::SerialTransport; +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +/// Tool: query device capabilities (GPIO pins, LED pin) from firmware. +pub struct HardwareCapabilitiesTool { + /// (board_name, transport) for each serial board. + boards: Vec<(String, Arc)>, +} + +impl HardwareCapabilitiesTool { + pub(crate) fn new(boards: Vec<(String, Arc)>) -> Self { + Self { boards } + } +} + +#[async_trait] +impl Tool for HardwareCapabilitiesTool { + fn name(&self) -> &str { + "hardware_capabilities" + } + + fn description(&self) -> &str { + "Query connected hardware for reported GPIO pins and LED pin. Use when: user asks what pins are available." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "board": { + "type": "string", + "description": "Optional board name. If omitted, queries all." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let filter = args.get("board").and_then(|v| v.as_str()); + let mut outputs = Vec::new(); + + for (board_name, transport) in &self.boards { + if let Some(b) = filter { + if b != board_name { + continue; + } + } + match transport.capabilities().await { + Ok(result) => { + let output = if result.success { + if let Ok(parsed) = + serde_json::from_str::(&result.output) + { + format!( + "{}: gpio {:?}, led_pin {:?}", + board_name, + parsed.get("gpio").unwrap_or(&json!([])), + parsed.get("led_pin").unwrap_or(&json!(null)) + ) + } else { + format!("{}: {}", board_name, result.output) + } + } else { + format!( + "{}: {}", + board_name, + result.error.as_deref().unwrap_or("unknown") + ) + }; + outputs.push(output); + } + Err(e) => { + outputs.push(format!("{}: error - {}", board_name, e)); + } + } + } + + let output = if outputs.is_empty() { + if filter.is_some() { + "No matching board or capabilities not supported.".to_string() + } else { + "No serial boards configured or capabilities not supported.".to_string() + } + } else { + outputs.join("\n") + }; + + Ok(ToolResult { + success: !outputs.is_empty(), + output, + error: None, + }) + } +} diff --git a/src/peripherals/mod.rs b/src/peripherals/mod.rs new file mode 100644 index 0000000..f3f8a8a --- /dev/null +++ b/src/peripherals/mod.rs @@ -0,0 +1,233 @@ +//! Hardware peripherals — STM32, RPi GPIO, etc. +//! +//! Peripherals extend the agent with physical capabilities. See +//! `docs/hardware-peripherals-design.md` for the full design. + +pub mod traits; + +#[cfg(feature = "hardware")] +pub mod serial; + +#[cfg(feature = "hardware")] +pub mod arduino_flash; +#[cfg(feature = "hardware")] +pub mod arduino_upload; +#[cfg(feature = "hardware")] +pub mod capabilities_tool; +#[cfg(feature = "hardware")] +pub mod nucleo_flash; +#[cfg(feature = "hardware")] +pub mod uno_q_bridge; +#[cfg(feature = "hardware")] +pub mod uno_q_setup; + +#[cfg(all(feature = "peripheral-rpi", target_os = "linux"))] +pub mod rpi; + +pub use traits::Peripheral; + +use crate::config::{Config, PeripheralBoardConfig, PeripheralsConfig}; +#[cfg(feature = "hardware")] +use crate::tools::HardwareMemoryMapTool; +use crate::tools::Tool; +use anyhow::Result; + +/// List configured boards from config (no connection yet). +pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoardConfig> { + if !config.enabled { + return Vec::new(); + } + config.boards.iter().collect() +} + +/// Handle `zeroclaw peripheral` subcommands. +#[allow(clippy::module_name_repetitions)] +pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { + match cmd { + crate::PeripheralCommands::List => { + let boards = list_configured_boards(&config.peripherals); + if boards.is_empty() { + println!("No peripherals configured."); + println!(); + println!("Add one with: zeroclaw peripheral add "); + println!(" Example: zeroclaw peripheral add nucleo-f401re /dev/ttyACM0"); + println!(); + println!("Or add to config.toml:"); + println!(" [peripherals]"); + println!(" enabled = true"); + println!(); + println!(" [[peripherals.boards]]"); + println!(" board = \"nucleo-f401re\""); + println!(" transport = \"serial\""); + println!(" path = \"/dev/ttyACM0\""); + } else { + println!("Configured peripherals:"); + for b in boards { + let path = b.path.as_deref().unwrap_or("(native)"); + println!(" {} {} {}", b.board, b.transport, path); + } + } + } + crate::PeripheralCommands::Add { board, path } => { + let transport = if path == "native" { "native" } else { "serial" }; + let path_opt = if path == "native" { + None + } else { + Some(path.clone()) + }; + + let mut cfg = crate::config::Config::load_or_init()?; + cfg.peripherals.enabled = true; + + if cfg + .peripherals + .boards + .iter() + .any(|b| b.board == board && b.path.as_deref() == path_opt.as_deref()) + { + println!("Board {} at {:?} already configured.", board, path_opt); + return Ok(()); + } + + cfg.peripherals.boards.push(PeripheralBoardConfig { + board: board.clone(), + transport: transport.to_string(), + path: path_opt, + baud: 115_200, + }); + cfg.save()?; + println!("Added {} at {}. Restart daemon to apply.", board, path); + } + #[cfg(feature = "hardware")] + crate::PeripheralCommands::Flash { port } => { + let port_str = arduino_flash::resolve_port(config, port.as_deref()) + .or_else(|| port.clone()) + .ok_or_else(|| anyhow::anyhow!( + "No port specified. Use --port /dev/cu.usbmodem* or add arduino-uno to config.toml" + ))?; + arduino_flash::flash_arduino_firmware(&port_str)?; + } + #[cfg(not(feature = "hardware"))] + crate::PeripheralCommands::Flash { .. } => { + println!("Arduino flash requires the 'hardware' feature."); + println!("Build with: cargo build --features hardware"); + } + #[cfg(feature = "hardware")] + crate::PeripheralCommands::SetupUnoQ { host } => { + uno_q_setup::setup_uno_q_bridge(host.as_deref())?; + } + #[cfg(not(feature = "hardware"))] + crate::PeripheralCommands::SetupUnoQ { .. } => { + println!("Uno Q setup requires the 'hardware' feature."); + println!("Build with: cargo build --features hardware"); + } + #[cfg(feature = "hardware")] + crate::PeripheralCommands::FlashNucleo => { + nucleo_flash::flash_nucleo_firmware()?; + } + #[cfg(not(feature = "hardware"))] + crate::PeripheralCommands::FlashNucleo => { + println!("Nucleo flash requires the 'hardware' feature."); + println!("Build with: cargo build --features hardware"); + } + } + Ok(()) +} + +/// Create and connect peripherals from config, returning their tools. +/// Returns empty vec if peripherals disabled or hardware feature off. +#[cfg(feature = "hardware")] +pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result>> { + if !config.enabled || config.boards.is_empty() { + return Ok(Vec::new()); + } + + let mut tools: Vec> = Vec::new(); + let mut serial_transports: Vec<(String, std::sync::Arc)> = Vec::new(); + + for board in &config.boards { + // Arduino Uno Q: Bridge transport (socket to local Bridge app) + if board.transport == "bridge" && (board.board == "arduino-uno-q" || board.board == "uno-q") + { + tools.push(Box::new(uno_q_bridge::UnoQGpioReadTool)); + tools.push(Box::new(uno_q_bridge::UnoQGpioWriteTool)); + tracing::info!(board = %board.board, "Uno Q Bridge GPIO tools added"); + continue; + } + + // Native transport: RPi GPIO (Linux only) + #[cfg(all(feature = "peripheral-rpi", target_os = "linux"))] + if board.transport == "native" + && (board.board == "rpi-gpio" || board.board == "raspberry-pi") + { + match rpi::RpiGpioPeripheral::connect_from_config(board).await { + Ok(peripheral) => { + tools.extend(peripheral.tools()); + tracing::info!(board = %board.board, "RPi GPIO peripheral connected"); + } + Err(e) => { + tracing::warn!("Failed to connect RPi GPIO {}: {}", board.board, e); + } + } + continue; + } + + // Serial transport (STM32, ESP32, Arduino, etc.) + if board.transport != "serial" { + continue; + } + if board.path.is_none() { + tracing::warn!("Skipping serial board {}: no path", board.board); + continue; + } + + match serial::SerialPeripheral::connect(board).await { + Ok(peripheral) => { + let mut p = peripheral; + if p.connect().await.is_err() { + tracing::warn!("Peripheral {} connect warning (continuing)", p.name()); + } + serial_transports.push((board.board.clone(), p.transport())); + tools.extend(p.tools()); + if board.board == "arduino-uno" { + if let Some(ref path) = board.path { + tools.push(Box::new(arduino_upload::ArduinoUploadTool::new( + path.clone(), + ))); + tracing::info!("Arduino upload tool added (port: {})", path); + } + } + tracing::info!(board = %board.board, "Serial peripheral connected"); + } + Err(e) => { + tracing::warn!("Failed to connect {}: {}", board.board, e); + } + } + } + + // Phase B: Add hardware tools when any boards configured + if !tools.is_empty() { + let board_names: Vec = config.boards.iter().map(|b| b.board.clone()).collect(); + tools.push(Box::new(HardwareMemoryMapTool::new(board_names.clone()))); + tools.push(Box::new(crate::tools::HardwareBoardInfoTool::new( + board_names.clone(), + ))); + tools.push(Box::new(crate::tools::HardwareMemoryReadTool::new( + board_names, + ))); + } + + // Phase C: Add hardware_capabilities tool when any serial boards + if !serial_transports.is_empty() { + tools.push(Box::new(capabilities_tool::HardwareCapabilitiesTool::new( + serial_transports, + ))); + } + + Ok(tools) +} + +#[cfg(not(feature = "hardware"))] +pub async fn create_peripheral_tools(_config: &PeripheralsConfig) -> Result>> { + Ok(Vec::new()) +} diff --git a/src/peripherals/nucleo_flash.rs b/src/peripherals/nucleo_flash.rs new file mode 100644 index 0000000..5558872 --- /dev/null +++ b/src/peripherals/nucleo_flash.rs @@ -0,0 +1,83 @@ +//! Flash ZeroClaw Nucleo-F401RE firmware via probe-rs. +//! +//! Builds the Embassy firmware and flashes via ST-Link (built into Nucleo). +//! Requires: cargo install probe-rs-tools --locked + +use anyhow::{Context, Result}; +use std::path::PathBuf; +use std::process::Command; + +const CHIP: &str = "STM32F401RETx"; +const TARGET: &str = "thumbv7em-none-eabihf"; + +/// Check if probe-rs CLI is available (from probe-rs-tools). +pub fn probe_rs_available() -> bool { + Command::new("probe-rs") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) +} + +/// Flash ZeroClaw Nucleo firmware. Builds from firmware/zeroclaw-nucleo. +pub fn flash_nucleo_firmware() -> Result<()> { + if !probe_rs_available() { + anyhow::bail!( + "probe-rs not found. Install it:\n cargo install probe-rs-tools --locked\n\n\ + Or: curl -LsSf https://github.com/probe-rs/probe-rs/releases/latest/download/probe-rs-tools-installer.sh | sh\n\n\ + Connect Nucleo via USB (ST-Link). Then run this command again." + ); + } + + // CARGO_MANIFEST_DIR = repo root (zeroclaw's Cargo.toml) + let repo_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let firmware_dir = repo_root.join("firmware").join("zeroclaw-nucleo"); + if !firmware_dir.join("Cargo.toml").exists() { + anyhow::bail!( + "Nucleo firmware not found at {}. Run from zeroclaw repo root.", + firmware_dir.display() + ); + } + + println!("Building ZeroClaw Nucleo firmware..."); + let build = Command::new("cargo") + .args(["build", "--release", "--target", TARGET]) + .current_dir(&firmware_dir) + .output() + .context("cargo build failed")?; + + if !build.status.success() { + let stderr = String::from_utf8_lossy(&build.stderr); + anyhow::bail!("Build failed:\n{}", stderr); + } + + let elf_path = firmware_dir + .join("target") + .join(TARGET) + .join("release") + .join("zeroclaw-nucleo"); + + if !elf_path.exists() { + anyhow::bail!("Built binary not found at {}", elf_path.display()); + } + + println!("Flashing to Nucleo-F401RE (connect via USB)..."); + let flash = Command::new("probe-rs") + .args(["run", "--chip", CHIP, elf_path.to_str().unwrap()]) + .output() + .context("probe-rs run failed")?; + + if !flash.status.success() { + let stderr = String::from_utf8_lossy(&flash.stderr); + anyhow::bail!( + "Flash failed:\n{}\n\n\ + Ensure Nucleo is connected via USB. The ST-Link is built into the board.", + stderr + ); + } + + println!("ZeroClaw Nucleo firmware flashed successfully."); + println!("The Nucleo now supports: ping, capabilities, gpio_read, gpio_write."); + println!("Add to config.toml: board = \"nucleo-f401re\", transport = \"serial\", path = \"/dev/ttyACM0\""); + Ok(()) +} diff --git a/src/peripherals/rpi.rs b/src/peripherals/rpi.rs new file mode 100644 index 0000000..6cea075 --- /dev/null +++ b/src/peripherals/rpi.rs @@ -0,0 +1,173 @@ +//! Raspberry Pi GPIO peripheral — native rppal access. +//! +//! Only compiled when `peripheral-rpi` feature is enabled and target is Linux. +//! Uses BCM pin numbering (e.g. GPIO 17, 27). + +use crate::config::PeripheralBoardConfig; +use crate::peripherals::traits::Peripheral; +use crate::tools::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::{json, Value}; + +/// RPi GPIO peripheral — direct access via rppal. +pub struct RpiGpioPeripheral { + board: PeripheralBoardConfig, +} + +impl RpiGpioPeripheral { + /// Create a new RPi GPIO peripheral from config. + pub fn new(board: PeripheralBoardConfig) -> Self { + Self { board } + } + + /// Attempt to connect (init rppal). Returns Ok if GPIO is available. + pub async fn connect_from_config(board: &PeripheralBoardConfig) -> anyhow::Result { + let mut peripheral = Self::new(board.clone()); + peripheral.connect().await?; + Ok(peripheral) + } +} + +#[async_trait] +impl Peripheral for RpiGpioPeripheral { + fn name(&self) -> &str { + &self.board.board + } + + fn board_type(&self) -> &str { + "rpi-gpio" + } + + async fn connect(&mut self) -> anyhow::Result<()> { + // Verify GPIO is accessible by doing a no-op init + let result = tokio::task::spawn_blocking(|| rppal::gpio::Gpio::new()).await??; + drop(result); + Ok(()) + } + + async fn disconnect(&mut self) -> anyhow::Result<()> { + Ok(()) + } + + async fn health_check(&self) -> bool { + tokio::task::spawn_blocking(|| rppal::gpio::Gpio::new().is_ok()) + .await + .unwrap_or(false) + } + + fn tools(&self) -> Vec> { + vec![Box::new(RpiGpioReadTool), Box::new(RpiGpioWriteTool)] + } +} + +/// Tool: read GPIO pin value (BCM numbering). +struct RpiGpioReadTool; + +#[async_trait] +impl Tool for RpiGpioReadTool { + fn name(&self) -> &str { + "gpio_read" + } + + fn description(&self) -> &str { + "Read the value (0 or 1) of a GPIO pin on Raspberry Pi. Uses BCM pin numbers (e.g. 17, 27)." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "BCM GPIO pin number (e.g. 17, 27)" + } + }, + "required": ["pin"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + let pin_u8 = pin as u8; + + let value = tokio::task::spawn_blocking(move || { + let gpio = rppal::gpio::Gpio::new()?; + let pin = gpio.get(pin_u8)?.into_input(); + Ok::<_, anyhow::Error>(match pin.read() { + rppal::gpio::Level::Low => 0, + rppal::gpio::Level::High => 1, + }) + }) + .await??; + + Ok(ToolResult { + success: true, + output: format!("pin {} = {}", pin, value), + error: None, + }) + } +} + +/// Tool: write GPIO pin value (BCM numbering). +struct RpiGpioWriteTool; + +#[async_trait] +impl Tool for RpiGpioWriteTool { + fn name(&self) -> &str { + "gpio_write" + } + + fn description(&self) -> &str { + "Set a GPIO pin high (1) or low (0) on Raspberry Pi. Uses BCM pin numbers." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "BCM GPIO pin number" + }, + "value": { + "type": "integer", + "description": "0 for low, 1 for high" + } + }, + "required": ["pin", "value"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + let value = args + .get("value") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' parameter"))?; + let pin_u8 = pin as u8; + let level = match value { + 0 => rppal::gpio::Level::Low, + _ => rppal::gpio::Level::High, + }; + + tokio::task::spawn_blocking(move || { + let gpio = rppal::gpio::Gpio::new()?; + let mut pin = gpio.get(pin_u8)?.into_output(); + pin.write(level); + Ok::<_, anyhow::Error>(()) + }) + .await??; + + Ok(ToolResult { + success: true, + output: format!("pin {} = {}", pin, value), + error: None, + }) + } +} diff --git a/src/peripherals/serial.rs b/src/peripherals/serial.rs new file mode 100644 index 0000000..2bcec56 --- /dev/null +++ b/src/peripherals/serial.rs @@ -0,0 +1,275 @@ +//! Serial peripheral — STM32 and similar boards over USB CDC/serial. +//! +//! Protocol: newline-delimited JSON. +//! Request: {"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}} +//! Response: {"id":"1","ok":true,"result":"done"} + +use super::traits::Peripheral; +use crate::config::PeripheralBoardConfig; +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Mutex; +use tokio_serial::{SerialPortBuilderExt, SerialStream}; + +/// Allowed serial path patterns (security: deny arbitrary paths). +const ALLOWED_PATH_PREFIXES: &[&str] = &[ + "/dev/ttyACM", + "/dev/ttyUSB", + "/dev/tty.usbmodem", + "/dev/cu.usbmodem", + "/dev/tty.usbserial", + "/dev/cu.usbserial", // Arduino Uno (FTDI), clones + "COM", // Windows +]; + +fn is_path_allowed(path: &str) -> bool { + ALLOWED_PATH_PREFIXES.iter().any(|p| path.starts_with(p)) +} + +/// JSON request/response over serial. +async fn send_request(port: &mut SerialStream, cmd: &str, args: Value) -> anyhow::Result { + static ID: AtomicU64 = AtomicU64::new(0); + let id = ID.fetch_add(1, Ordering::Relaxed); + let id_str = id.to_string(); + + let req = json!({ + "id": id_str, + "cmd": cmd, + "args": args + }); + let line = format!("{}\n", req); + + port.write_all(line.as_bytes()).await?; + port.flush().await?; + + let mut buf = Vec::new(); + let mut b = [0u8; 1]; + while port.read_exact(&mut b).await.is_ok() { + if b[0] == b'\n' { + break; + } + buf.push(b[0]); + } + let line_str = String::from_utf8_lossy(&buf); + let resp: Value = serde_json::from_str(line_str.trim())?; + let resp_id = resp["id"].as_str().unwrap_or(""); + if resp_id != id_str { + anyhow::bail!("Response id mismatch: expected {}, got {}", id_str, resp_id); + } + Ok(resp) +} + +/// Shared serial transport for tools. Pub(crate) for capabilities tool. +pub(crate) struct SerialTransport { + port: Mutex, +} + +/// Timeout for serial request/response (seconds). +const SERIAL_TIMEOUT_SECS: u64 = 5; + +impl SerialTransport { + async fn request(&self, cmd: &str, args: Value) -> anyhow::Result { + let mut port = self.port.lock().await; + let resp = tokio::time::timeout( + std::time::Duration::from_secs(SERIAL_TIMEOUT_SECS), + send_request(&mut port, cmd, args), + ) + .await + .map_err(|_| { + anyhow::anyhow!("Serial request timed out after {}s", SERIAL_TIMEOUT_SECS) + })??; + + let ok = resp["ok"].as_bool().unwrap_or(false); + let result = resp["result"] + .as_str() + .map(String::from) + .unwrap_or_else(|| resp["result"].to_string()); + let error = resp["error"].as_str().map(String::from); + + Ok(ToolResult { + success: ok, + output: result, + error, + }) + } + + /// Phase C: fetch capabilities from device (gpio pins, led_pin). + pub async fn capabilities(&self) -> anyhow::Result { + self.request("capabilities", json!({})).await + } +} + +/// Serial peripheral for STM32, Arduino, etc. over USB CDC. +pub struct SerialPeripheral { + name: String, + board_type: String, + transport: Arc, +} + +impl SerialPeripheral { + /// Create and connect to a serial peripheral. + #[allow(clippy::unused_async)] + pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result { + let path = config + .path + .as_deref() + .ok_or_else(|| anyhow::anyhow!("Serial peripheral requires path"))?; + + if !is_path_allowed(path) { + anyhow::bail!( + "Serial path not allowed: {}. Allowed: /dev/ttyACM*, /dev/ttyUSB*, /dev/tty.usbmodem*, /dev/cu.usbmodem*", + path + ); + } + + let port = tokio_serial::new(path, config.baud) + .open_native_async() + .map_err(|e| anyhow::anyhow!("Failed to open {}: {}", path, e))?; + + let name = format!("{}-{}", config.board, path.replace('/', "_")); + let transport = Arc::new(SerialTransport { + port: Mutex::new(port), + }); + + Ok(Self { + name: name.clone(), + board_type: config.board.clone(), + transport, + }) + } +} + +#[async_trait] +impl Peripheral for SerialPeripheral { + fn name(&self) -> &str { + &self.name + } + + fn board_type(&self) -> &str { + &self.board_type + } + + async fn connect(&mut self) -> anyhow::Result<()> { + Ok(()) + } + + async fn disconnect(&mut self) -> anyhow::Result<()> { + Ok(()) + } + + async fn health_check(&self) -> bool { + self.transport + .request("ping", json!({})) + .await + .map(|r| r.success) + .unwrap_or(false) + } + + fn tools(&self) -> Vec> { + vec![ + Box::new(GpioReadTool { + transport: self.transport.clone(), + }), + Box::new(GpioWriteTool { + transport: self.transport.clone(), + }), + ] + } +} + +impl SerialPeripheral { + /// Expose transport for capabilities tool (Phase C). + pub(crate) fn transport(&self) -> Arc { + self.transport.clone() + } +} + +/// Tool: read GPIO pin value. +struct GpioReadTool { + transport: Arc, +} + +#[async_trait] +impl Tool for GpioReadTool { + fn name(&self) -> &str { + "gpio_read" + } + + fn description(&self) -> &str { + "Read the value (0 or 1) of a GPIO pin on a connected peripheral (e.g. STM32 Nucleo)" + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "GPIO pin number (e.g. 13 for LED on Nucleo)" + } + }, + "required": ["pin"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + self.transport + .request("gpio_read", json!({ "pin": pin })) + .await + } +} + +/// Tool: write GPIO pin value. +struct GpioWriteTool { + transport: Arc, +} + +#[async_trait] +impl Tool for GpioWriteTool { + fn name(&self) -> &str { + "gpio_write" + } + + fn description(&self) -> &str { + "Set a GPIO pin high (1) or low (0) on a connected peripheral (e.g. turn on/off LED)" + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "GPIO pin number" + }, + "value": { + "type": "integer", + "description": "0 for low, 1 for high" + } + }, + "required": ["pin", "value"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + let value = args + .get("value") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' parameter"))?; + self.transport + .request("gpio_write", json!({ "pin": pin, "value": value })) + .await + } +} diff --git a/src/peripherals/traits.rs b/src/peripherals/traits.rs new file mode 100644 index 0000000..6081d1d --- /dev/null +++ b/src/peripherals/traits.rs @@ -0,0 +1,33 @@ +//! Peripheral trait — hardware boards (STM32, RPi GPIO) that expose tools. +//! +//! Peripherals are the agent's "arms and legs": remote devices that run minimal +//! firmware and expose capabilities (GPIO, sensors, actuators) as tools. + +use async_trait::async_trait; + +use crate::tools::Tool; + +/// A hardware peripheral that exposes capabilities as tools. +/// +/// Implement this for boards like Nucleo-F401RE (serial), RPi GPIO (native), etc. +/// When connected, the peripheral's tools are merged into the agent's tool registry. +#[async_trait] +pub trait Peripheral: Send + Sync { + /// Human-readable peripheral name (e.g. "nucleo-f401re-0") + fn name(&self) -> &str; + + /// Board type identifier (e.g. "nucleo-f401re", "rpi-gpio") + fn board_type(&self) -> &str; + + /// Connect to the peripheral (open serial, init GPIO, etc.) + async fn connect(&mut self) -> anyhow::Result<()>; + + /// Disconnect and release resources + async fn disconnect(&mut self) -> anyhow::Result<()>; + + /// Check if the peripheral is reachable and responsive + async fn health_check(&self) -> bool; + + /// Tools this peripheral provides (e.g. gpio_read, gpio_write, sensor_read) + fn tools(&self) -> Vec>; +} diff --git a/src/peripherals/uno_q_bridge.rs b/src/peripherals/uno_q_bridge.rs new file mode 100644 index 0000000..a621831 --- /dev/null +++ b/src/peripherals/uno_q_bridge.rs @@ -0,0 +1,151 @@ +//! Arduino Uno Q Bridge — GPIO via socket to Bridge app. +//! +//! When ZeroClaw runs on Uno Q, the Bridge app (Python + MCU) exposes +//! digitalWrite/digitalRead over a local socket. These tools connect to it. + +use crate::tools::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +const BRIDGE_HOST: &str = "127.0.0.1"; +const BRIDGE_PORT: u16 = 9999; + +async fn bridge_request(cmd: &str, args: &[String]) -> anyhow::Result { + let addr = format!("{}:{}", BRIDGE_HOST, BRIDGE_PORT); + let mut stream = tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(&addr)) + .await + .map_err(|_| anyhow::anyhow!("Bridge connection timed out"))??; + + let msg = format!("{} {}\n", cmd, args.join(" ")); + stream.write_all(msg.as_bytes()).await?; + + let mut buf = vec![0u8; 64]; + let n = tokio::time::timeout(Duration::from_secs(3), stream.read(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("Bridge response timed out"))??; + let resp = String::from_utf8_lossy(&buf[..n]).trim().to_string(); + Ok(resp) +} + +/// Tool: read GPIO pin via Uno Q Bridge. +pub struct UnoQGpioReadTool; + +#[async_trait] +impl Tool for UnoQGpioReadTool { + fn name(&self) -> &str { + "gpio_read" + } + + fn description(&self) -> &str { + "Read GPIO pin value (0 or 1) on Arduino Uno Q. Requires zeroclaw-uno-q-bridge app running." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "GPIO pin number (e.g. 13 for LED)" + } + }, + "required": ["pin"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + match bridge_request("gpio_read", &[pin.to_string()]).await { + Ok(resp) => { + if resp.starts_with("error:") { + Ok(ToolResult { + success: false, + output: resp.clone(), + error: Some(resp), + }) + } else { + Ok(ToolResult { + success: true, + output: resp, + error: None, + }) + } + } + Err(e) => Ok(ToolResult { + success: false, + output: format!("Bridge error: {}", e), + error: Some(e.to_string()), + }), + } + } +} + +/// Tool: write GPIO pin via Uno Q Bridge. +pub struct UnoQGpioWriteTool; + +#[async_trait] +impl Tool for UnoQGpioWriteTool { + fn name(&self) -> &str { + "gpio_write" + } + + fn description(&self) -> &str { + "Set GPIO pin high (1) or low (0) on Arduino Uno Q. Requires zeroclaw-uno-q-bridge app running." + } + + fn parameters_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "pin": { + "type": "integer", + "description": "GPIO pin number" + }, + "value": { + "type": "integer", + "description": "0 for low, 1 for high" + } + }, + "required": ["pin", "value"] + }) + } + + async fn execute(&self, args: Value) -> anyhow::Result { + let pin = args + .get("pin") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'pin' parameter"))?; + let value = args + .get("value") + .and_then(|v| v.as_u64()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' parameter"))?; + match bridge_request("gpio_write", &[pin.to_string(), value.to_string()]).await { + Ok(resp) => { + if resp.starts_with("error:") { + Ok(ToolResult { + success: false, + output: resp.clone(), + error: Some(resp), + }) + } else { + Ok(ToolResult { + success: true, + output: "done".into(), + error: None, + }) + } + } + Err(e) => Ok(ToolResult { + success: false, + output: format!("Bridge error: {}", e), + error: Some(e.to_string()), + }), + } + } +} diff --git a/src/peripherals/uno_q_setup.rs b/src/peripherals/uno_q_setup.rs new file mode 100644 index 0000000..424bc89 --- /dev/null +++ b/src/peripherals/uno_q_setup.rs @@ -0,0 +1,143 @@ +//! Deploy ZeroClaw Bridge app to Arduino Uno Q. + +use anyhow::{Context, Result}; +use std::process::Command; + +const BRIDGE_APP_NAME: &str = "zeroclaw-uno-q-bridge"; + +/// Deploy the Bridge app. If host is Some, scp from repo and ssh to start. +/// If host is None, assume we're ON the Uno Q — use embedded files and start. +pub fn setup_uno_q_bridge(host: Option<&str>) -> Result<()> { + let bridge_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("firmware") + .join("zeroclaw-uno-q-bridge"); + + if let Some(h) = host { + if bridge_dir.exists() { + deploy_remote(h, &bridge_dir)?; + } else { + anyhow::bail!( + "Bridge app not found at {}. Run from zeroclaw repo root.", + bridge_dir.display() + ); + } + } else { + deploy_local(if bridge_dir.exists() { + Some(&bridge_dir) + } else { + None + })?; + } + Ok(()) +} + +fn deploy_remote(host: &str, bridge_dir: &std::path::Path) -> Result<()> { + let ssh_target = if host.contains('@') { + host.to_string() + } else { + format!("arduino@{}", host) + }; + + println!("Copying Bridge app to {}...", host); + let status = Command::new("ssh") + .args([&ssh_target, "mkdir", "-p", "~/ArduinoApps"]) + .status() + .context("ssh mkdir failed")?; + if !status.success() { + anyhow::bail!("Failed to create ArduinoApps dir on Uno Q"); + } + + let status = Command::new("scp") + .args([ + "-r", + bridge_dir.to_str().unwrap(), + &format!("{}:~/ArduinoApps/", ssh_target), + ]) + .status() + .context("scp failed")?; + if !status.success() { + anyhow::bail!("Failed to copy Bridge app"); + } + + println!("Starting Bridge app on Uno Q..."); + let status = Command::new("ssh") + .args([ + &ssh_target, + "arduino-app-cli", + "app", + "start", + "~/ArduinoApps/zeroclaw-uno-q-bridge", + ]) + .status() + .context("arduino-app-cli start failed")?; + if !status.success() { + anyhow::bail!("Failed to start Bridge app. Ensure arduino-app-cli is installed on Uno Q."); + } + + println!("ZeroClaw Bridge app started. Add to config.toml:"); + println!(" [[peripherals.boards]]"); + println!(" board = \"arduino-uno-q\""); + println!(" transport = \"bridge\""); + Ok(()) +} + +fn deploy_local(bridge_dir: Option<&std::path::Path>) -> Result<()> { + let home = std::env::var("HOME").unwrap_or_else(|_| "/home/arduino".into()); + let apps_dir = std::path::Path::new(&home).join("ArduinoApps"); + let dest_dir = apps_dir.join(BRIDGE_APP_NAME); + + std::fs::create_dir_all(&dest_dir).context("create dest dir")?; + + if let Some(src) = bridge_dir { + println!("Copying Bridge app from repo..."); + copy_dir(src, &dest_dir)?; + } else { + println!("Writing embedded Bridge app..."); + write_embedded_bridge(&dest_dir)?; + } + + println!("Starting Bridge app..."); + let status = Command::new("arduino-app-cli") + .args(["app", "start", dest_dir.to_str().unwrap()]) + .status() + .context("arduino-app-cli start failed")?; + if !status.success() { + anyhow::bail!("Failed to start Bridge app. Ensure arduino-app-cli is installed on Uno Q."); + } + + println!("ZeroClaw Bridge app started."); + Ok(()) +} + +fn write_embedded_bridge(dest: &std::path::Path) -> Result<()> { + let app_yaml = include_str!("../../firmware/zeroclaw-uno-q-bridge/app.yaml"); + let sketch_ino = include_str!("../../firmware/zeroclaw-uno-q-bridge/sketch/sketch.ino"); + let sketch_yaml = include_str!("../../firmware/zeroclaw-uno-q-bridge/sketch/sketch.yaml"); + let main_py = include_str!("../../firmware/zeroclaw-uno-q-bridge/python/main.py"); + let requirements = include_str!("../../firmware/zeroclaw-uno-q-bridge/python/requirements.txt"); + + std::fs::write(dest.join("app.yaml"), app_yaml)?; + std::fs::create_dir_all(dest.join("sketch"))?; + std::fs::write(dest.join("sketch").join("sketch.ino"), sketch_ino)?; + std::fs::write(dest.join("sketch").join("sketch.yaml"), sketch_yaml)?; + std::fs::create_dir_all(dest.join("python"))?; + std::fs::write(dest.join("python").join("main.py"), main_py)?; + std::fs::write(dest.join("python").join("requirements.txt"), requirements)?; + Ok(()) +} + +fn copy_dir(src: &std::path::Path, dst: &std::path::Path) -> Result<()> { + for entry in std::fs::read_dir(src)? { + let e = entry?; + let name = e.file_name(); + let src_path = src.join(&name); + let dst_path = dst.join(&name); + if e.file_type()?.is_dir() { + std::fs::create_dir_all(&dst_path)?; + copy_dir(&src_path, &dst_path)?; + } else { + std::fs::copy(&src_path, &dst_path)?; + } + } + Ok(()) +} diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 9cddba1..1f45c7e 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,10 +1,15 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; pub struct AnthropicProvider { - api_key: Option, + credential: Option, + base_url: String, client: Client, } @@ -31,13 +36,91 @@ struct ChatResponse { #[derive(Debug, Deserialize)] struct ContentBlock { - text: String, + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: Option, +} + +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + content: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum NativeContentOut { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + }, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + name: String, + description: String, + input_schema: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + #[serde(default)] + content: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeContentIn { + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: Option, + #[serde(default)] + id: Option, + #[serde(default)] + name: Option, + #[serde(default)] + input: Option, } impl AnthropicProvider { - pub fn new(api_key: Option<&str>) -> Self { + pub fn new(credential: Option<&str>) -> Self { + Self::with_base_url(credential, None) + } + + pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self { + let base_url = base_url + .map(|u| u.trim_end_matches('/')) + .unwrap_or("https://api.anthropic.com") + .to_string(); Self { - api_key: api_key.map(ToString::to_string), + credential: credential + .map(str::trim) + .filter(|k| !k.is_empty()) + .map(ToString::to_string), + base_url, client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -45,6 +128,192 @@ impl AnthropicProvider { .unwrap_or_else(|_| Client::new()), } } + + fn is_setup_token(token: &str) -> bool { + token.starts_with("sk-ant-oat01-") + } + + fn apply_auth( + &self, + request: reqwest::RequestBuilder, + credential: &str, + ) -> reqwest::RequestBuilder { + if Self::is_setup_token(credential) { + request + .header("Authorization", format!("Bearer {credential}")) + .header("anthropic-beta", "oauth-2025-04-20") + } else { + request.header("x-api-key", credential) + } + } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + let items = tools?; + if items.is_empty() { + return None; + } + Some( + items + .iter() + .map(|tool| NativeToolSpec { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: tool.parameters.clone(), + }) + .collect(), + ) + } + + fn parse_assistant_tool_call_message(content: &str) -> Option> { + let value = serde_json::from_str::(content).ok()?; + let tool_calls = value + .get("tool_calls") + .and_then(|v| serde_json::from_value::>(v.clone()).ok())?; + + let mut blocks = Vec::new(); + if let Some(text) = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(str::trim) + .filter(|t| !t.is_empty()) + { + blocks.push(NativeContentOut::Text { + text: text.to_string(), + }); + } + for call in tool_calls { + let input = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())); + blocks.push(NativeContentOut::ToolUse { + id: call.id, + name: call.name, + input, + }); + } + Some(blocks) + } + + fn parse_tool_result_message(content: &str) -> Option { + let value = serde_json::from_str::(content).ok()?; + let tool_use_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str)? + .to_string(); + let result = value + .get("content") + .and_then(serde_json::Value::as_str) + .unwrap_or("") + .to_string(); + Some(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::ToolResult { + tool_use_id, + content: result, + }], + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec) { + let mut system_prompt = None; + let mut native_messages = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if system_prompt.is_none() { + system_prompt = Some(msg.content.clone()); + } + } + "assistant" => { + if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) { + native_messages.push(NativeMessage { + role: "assistant".to_string(), + content: blocks, + }); + } else { + native_messages.push(NativeMessage { + role: "assistant".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + "tool" => { + if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) { + native_messages.push(tool_result); + } else { + native_messages.push(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + _ => { + native_messages.push(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + } + + (system_prompt, native_messages) + } + + fn parse_text_response(response: ChatResponse) -> anyhow::Result { + response + .content + .into_iter() + .find(|c| c.kind == "text") + .and_then(|c| c.text) + .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) + } + + fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse { + let mut text_parts = Vec::new(); + let mut tool_calls = Vec::new(); + + for block in response.content { + match block.kind.as_str() { + "text" => { + if let Some(text) = block.text.map(|t| t.trim().to_string()) { + if !text.is_empty() { + text_parts.push(text); + } + } + } + "tool_use" => { + let name = block.name.unwrap_or_default(); + if name.is_empty() { + continue; + } + let arguments = block + .input + .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new())); + tool_calls.push(ProviderToolCall { + id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name, + arguments: arguments.to_string(), + }); + } + _ => {} + } + } + + ProviderChatResponse { + text: if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }, + tool_calls, + } + } } #[async_trait] @@ -56,8 +325,10 @@ impl Provider for AnthropicProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { - anyhow::anyhow!("Anthropic API key not set. Set ANTHROPIC_API_KEY or edit config.toml.") + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." + ) })?; let request = ChatRequest { @@ -71,29 +342,65 @@ impl Provider for AnthropicProvider { temperature, }; - let response = self + let mut request = self .client - .post("https://api.anthropic.com/v1/messages") - .header("x-api-key", api_key) + .post(format!("{}/v1/messages", self.base_url)) .header("anthropic-version", "2023-06-01") .header("content-type", "application/json") - .json(&request) - .send() - .await?; + .json(&request); + + request = self.apply_auth(request, credential); + + let response = request.send().await?; if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("Anthropic API error: {error}"); + return Err(super::api_error("Anthropic", response).await); } let chat_response: ChatResponse = response.json().await?; + Self::parse_text_response(chat_response) + } - chat_response - .content - .into_iter() - .next() - .map(|c| c.text) - .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." + ) + })?; + + let (system_prompt, messages) = Self::convert_messages(request.messages); + let native_request = NativeChatRequest { + model: model.to_string(), + max_tokens: 4096, + system: system_prompt, + messages, + temperature, + tools: Self::convert_tools(request.tools), + }; + + let req = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&native_request); + + let response = self.apply_auth(req, credential).send().await?; + if !response.status().is_success() { + return Err(super::api_error("Anthropic", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + Ok(Self::parse_native_response(native_response)) + } + + fn supports_native_tools(&self) -> bool { + true } } @@ -103,22 +410,52 @@ mod tests { #[test] fn creates_with_key() { - let p = AnthropicProvider::new(Some("sk-ant-test123")); - assert!(p.api_key.is_some()); - assert_eq!(p.api_key.as_deref(), Some("sk-ant-test123")); + let p = AnthropicProvider::new(Some("anthropic-test-credential")); + assert!(p.credential.is_some()); + assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential")); + assert_eq!(p.base_url, "https://api.anthropic.com"); } #[test] fn creates_without_key() { let p = AnthropicProvider::new(None); - assert!(p.api_key.is_none()); + assert!(p.credential.is_none()); + assert_eq!(p.base_url, "https://api.anthropic.com"); } #[test] fn creates_with_empty_key() { let p = AnthropicProvider::new(Some("")); - assert!(p.api_key.is_some()); - assert_eq!(p.api_key.as_deref(), Some("")); + assert!(p.credential.is_none()); + } + + #[test] + fn creates_with_whitespace_key() { + let p = AnthropicProvider::new(Some(" anthropic-test-credential ")); + assert!(p.credential.is_some()); + assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential")); + } + + #[test] + fn creates_with_custom_base_url() { + let p = AnthropicProvider::with_base_url( + Some("anthropic-credential"), + Some("https://api.example.com"), + ); + assert_eq!(p.base_url, "https://api.example.com"); + assert_eq!(p.credential.as_deref(), Some("anthropic-credential")); + } + + #[test] + fn custom_base_url_trims_trailing_slash() { + let p = AnthropicProvider::with_base_url(None, Some("https://api.example.com/")); + assert_eq!(p.base_url, "https://api.example.com"); + } + + #[test] + fn default_base_url_when_none_provided() { + let p = AnthropicProvider::with_base_url(None, None); + assert_eq!(p.base_url, "https://api.anthropic.com"); } #[tokio::test] @@ -130,11 +467,67 @@ mod tests { assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!( - err.contains("API key not set"), + err.contains("credentials not set"), "Expected key error, got: {err}" ); } + #[test] + fn setup_token_detection_works() { + assert!(AnthropicProvider::is_setup_token("sk-ant-oat01-abcdef")); + assert!(!AnthropicProvider::is_setup_token("sk-ant-api-key")); + } + + #[test] + fn apply_auth_uses_bearer_and_beta_for_setup_tokens() { + let provider = AnthropicProvider::new(None); + let request = provider + .apply_auth( + provider.client.get("https://api.anthropic.com/v1/models"), + "sk-ant-oat01-test-token", + ) + .build() + .expect("request should build"); + + assert_eq!( + request + .headers() + .get("authorization") + .and_then(|v| v.to_str().ok()), + Some("Bearer sk-ant-oat01-test-token") + ); + assert_eq!( + request + .headers() + .get("anthropic-beta") + .and_then(|v| v.to_str().ok()), + Some("oauth-2025-04-20") + ); + assert!(request.headers().get("x-api-key").is_none()); + } + + #[test] + fn apply_auth_uses_x_api_key_for_regular_tokens() { + let provider = AnthropicProvider::new(None); + let request = provider + .apply_auth( + provider.client.get("https://api.anthropic.com/v1/models"), + "sk-ant-api-key", + ) + .build() + .expect("request should build"); + + assert_eq!( + request + .headers() + .get("x-api-key") + .and_then(|v| v.to_str().ok()), + Some("sk-ant-api-key") + ); + assert!(request.headers().get("authorization").is_none()); + assert!(request.headers().get("anthropic-beta").is_none()); + } + #[tokio::test] async fn chat_with_system_fails_without_key() { let p = AnthropicProvider::new(None); @@ -186,7 +579,8 @@ mod tests { let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#; let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 1); - assert_eq!(resp.content[0].text, "Hello there!"); + assert_eq!(resp.content[0].kind, "text"); + assert_eq!(resp.content[0].text.as_deref(), Some("Hello there!")); } #[test] @@ -202,8 +596,8 @@ mod tests { r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#; let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 2); - assert_eq!(resp.content[0].text, "First"); - assert_eq!(resp.content[1].text, "Second"); + assert_eq!(resp.content[0].text.as_deref(), Some("First")); + assert_eq!(resp.content[1].text.as_deref(), Some("Second")); } #[test] diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 15f7a32..047c335 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -2,8 +2,12 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. -use crate::providers::traits::Provider; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, StreamChunk, StreamError, StreamOptions, StreamResult, ToolCall as ProviderToolCall, +}; use async_trait::async_trait; +use futures_util::{stream, StreamExt}; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -13,8 +17,11 @@ use serde::{Deserialize, Serialize}; pub struct OpenAiCompatibleProvider { pub(crate) name: String, pub(crate) base_url: String, - pub(crate) api_key: Option, + pub(crate) credential: Option, pub(crate) auth_header: AuthStyle, + /// When false, do not fall back to /v1/responses on chat completions 404. + /// GLM/Zhipu does not support the responses API. + supports_responses_fallback: bool, client: Client, } @@ -30,12 +37,18 @@ pub enum AuthStyle { } impl OpenAiCompatibleProvider { - pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self { + pub fn new( + name: &str, + base_url: &str, + credential: Option<&str>, + auth_style: AuthStyle, + ) -> Self { Self { name: name.to_string(), base_url: base_url.trim_end_matches('/').to_string(), - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), auth_header: auth_style, + supports_responses_fallback: true, client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -43,6 +56,90 @@ impl OpenAiCompatibleProvider { .unwrap_or_else(|_| Client::new()), } } + + /// Same as `new` but skips the /v1/responses fallback on 404. + /// Use for providers (e.g. GLM) that only support chat completions. + pub fn new_no_responses_fallback( + name: &str, + base_url: &str, + credential: Option<&str>, + auth_style: AuthStyle, + ) -> Self { + Self { + name: name.to_string(), + base_url: base_url.trim_end_matches('/').to_string(), + credential: credential.map(ToString::to_string), + auth_header: auth_style, + supports_responses_fallback: false, + client: Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + } + } + + /// Build the full URL for chat completions, detecting if base_url already includes the path. + /// This allows custom providers with non-standard endpoints (e.g., VolcEngine ARK uses + /// `/api/coding/v3/chat/completions` instead of `/v1/chat/completions`). + fn chat_completions_url(&self) -> String { + let has_full_endpoint = reqwest::Url::parse(&self.base_url) + .map(|url| { + url.path() + .trim_end_matches('/') + .ends_with("/chat/completions") + }) + .unwrap_or_else(|_| { + self.base_url + .trim_end_matches('/') + .ends_with("/chat/completions") + }); + + if has_full_endpoint { + self.base_url.clone() + } else { + format!("{}/chat/completions", self.base_url) + } + } + + fn path_ends_with(&self, suffix: &str) -> bool { + if let Ok(url) = reqwest::Url::parse(&self.base_url) { + return url.path().trim_end_matches('/').ends_with(suffix); + } + + self.base_url.trim_end_matches('/').ends_with(suffix) + } + + fn has_explicit_api_path(&self) -> bool { + let Ok(url) = reqwest::Url::parse(&self.base_url) else { + return false; + }; + + let path = url.path().trim_end_matches('/'); + !path.is_empty() && path != "/" + } + + /// Build the full URL for responses API, detecting if base_url already includes the path. + fn responses_url(&self) -> String { + if self.path_ends_with("/responses") { + return self.base_url.clone(); + } + + let normalized_base = self.base_url.trim_end_matches('/'); + + // If chat endpoint is explicitly configured, derive sibling responses endpoint. + if let Some(prefix) = normalized_base.strip_suffix("/chat/completions") { + return format!("{prefix}/responses"); + } + + // If an explicit API path already exists (e.g. /v1, /openai, /api/coding/v3), + // append responses directly to avoid duplicate /v1 segments. + if self.has_explicit_api_path() { + format!("{normalized_base}/responses") + } else { + format!("{normalized_base}/v1/responses") + } + } } #[derive(Debug, Serialize)] @@ -50,6 +147,8 @@ struct ChatRequest { model: String, messages: Vec, temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, } #[derive(Debug, Serialize)] @@ -59,7 +158,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -68,11 +167,288 @@ struct Choice { message: ResponseMessage, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] struct ResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +struct ToolCall { + #[serde(rename = "type")] + kind: Option, + function: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +struct Function { + name: Option, + arguments: Option, +} + +#[derive(Debug, Serialize)] +struct ResponsesRequest { + model: String, + input: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, +} + +#[derive(Debug, Serialize)] +struct ResponsesInput { + role: String, content: String, } +#[derive(Debug, Deserialize)] +struct ResponsesResponse { + #[serde(default)] + output: Vec, + #[serde(default)] + output_text: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponsesOutput { + #[serde(default)] + content: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponsesContent { + #[serde(rename = "type")] + kind: Option, + text: Option, +} + +// ═══════════════════════════════════════════════════════════════ +// Streaming support (SSE parser) +// ═══════════════════════════════════════════════════════════════ + +/// Server-Sent Event stream chunk for OpenAI-compatible streaming. +#[derive(Debug, Deserialize)] +struct StreamChunkResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct StreamChoice { + delta: StreamDelta, + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct StreamDelta { + #[serde(default)] + content: Option, +} + +/// Parse SSE (Server-Sent Events) stream from OpenAI-compatible providers. +/// Handles the `data: {...}` format and `[DONE]` sentinel. +fn parse_sse_line(line: &str) -> StreamResult> { + let line = line.trim(); + + // Skip empty lines and comments + if line.is_empty() || line.starts_with(':') { + return Ok(None); + } + + // SSE format: "data: {...}" + if let Some(data) = line.strip_prefix("data:") { + let data = data.trim(); + + // Check for [DONE] sentinel + if data == "[DONE]" { + return Ok(None); + } + + // Parse JSON delta + let chunk: StreamChunkResponse = serde_json::from_str(data).map_err(StreamError::Json)?; + + // Extract content from delta + if let Some(choice) = chunk.choices.first() { + if let Some(content) = &choice.delta.content { + return Ok(Some(content.clone())); + } + } + } + + Ok(None) +} + +/// Convert SSE byte stream to text chunks. +fn sse_bytes_to_chunks( + response: reqwest::Response, + count_tokens: bool, +) -> stream::BoxStream<'static, StreamResult> { + // Create a channel to send chunks + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + + tokio::spawn(async move { + // Buffer for incomplete lines + let mut buffer = String::new(); + + // Get response body as bytes stream + match response.error_for_status_ref() { + Ok(_) => {} + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + return; + } + } + + let mut bytes_stream = response.bytes_stream(); + + while let Some(item) = bytes_stream.next().await { + match item { + Ok(bytes) => { + // Convert bytes to string and process line by line + let text = match String::from_utf8(bytes.to_vec()) { + Ok(t) => t, + Err(e) => { + let _ = tx + .send(Err(StreamError::InvalidSse(format!( + "Invalid UTF-8: {}", + e + )))) + .await; + break; + } + }; + + buffer.push_str(&text); + + // Process complete lines + while let Some(pos) = buffer.find('\n') { + let line = buffer.drain(..=pos).collect::(); + buffer = buffer[pos + 1..].to_string(); + + match parse_sse_line(&line) { + Ok(Some(content)) => { + let mut chunk = StreamChunk::delta(content); + if count_tokens { + chunk = chunk.with_token_estimate(); + } + if tx.send(Ok(chunk)).await.is_err() { + return; // Receiver dropped + } + } + Ok(None) => {} + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + } + } + } + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + break; + } + } + } + + // Send final chunk + let _ = tx.send(Ok(StreamChunk::final_chunk())).await; + }); + + // Convert channel receiver to stream + stream::unfold(rx, |mut rx| async { + rx.recv().await.map(|chunk| (chunk, rx)) + }) + .boxed() +} + +fn first_nonempty(text: Option<&str>) -> Option { + text.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + }) +} + +fn extract_responses_text(response: ResponsesResponse) -> Option { + if let Some(text) = first_nonempty(response.output_text.as_deref()) { + return Some(text); + } + + for item in &response.output { + for content in &item.content { + if content.kind.as_deref() == Some("output_text") { + if let Some(text) = first_nonempty(content.text.as_deref()) { + return Some(text); + } + } + } + } + + for item in &response.output { + for content in &item.content { + if let Some(text) = first_nonempty(content.text.as_deref()) { + return Some(text); + } + } + } + + None +} + +impl OpenAiCompatibleProvider { + fn apply_auth_header( + &self, + req: reqwest::RequestBuilder, + credential: &str, + ) -> reqwest::RequestBuilder { + match &self.auth_header { + AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")), + AuthStyle::XApiKey => req.header("x-api-key", credential), + AuthStyle::Custom(header) => req.header(header, credential), + } + } + + async fn chat_via_responses( + &self, + credential: &str, + system_prompt: Option<&str>, + message: &str, + model: &str, + ) -> anyhow::Result { + let request = ResponsesRequest { + model: model.to_string(), + input: vec![ResponsesInput { + role: "user".to_string(), + content: message.to_string(), + }], + instructions: system_prompt.map(str::to_string), + stream: Some(false), + }; + + let url = self.responses_url(); + + let response = self + .apply_auth_header(self.client.post(&url).json(&request), credential) + .send() + .await?; + + if !response.status().is_success() { + let error = response.text().await?; + anyhow::bail!("{} Responses API error: {error}", self.name); + } + + let responses: ResponsesResponse = response.json().await?; + + extract_responses_text(responses) + .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) + } +} + #[async_trait] impl Provider for OpenAiCompatibleProvider { async fn chat_with_system( @@ -82,7 +458,7 @@ impl Provider for OpenAiCompatibleProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", self.name @@ -107,40 +483,298 @@ impl Provider for OpenAiCompatibleProvider { model: model.to_string(), messages, temperature, + stream: Some(false), }; - let url = format!("{}/v1/chat/completions", self.base_url); + let url = self.chat_completions_url(); - let mut req = self.client.post(&url).json(&request); - - match &self.auth_header { - AuthStyle::Bearer => { - req = req.header("Authorization", format!("Bearer {api_key}")); - } - AuthStyle::XApiKey => { - req = req.header("x-api-key", api_key.as_str()); - } - AuthStyle::Custom(header) => { - req = req.header(header.as_str(), api_key.as_str()); - } - } - - let response = req.send().await?; + let response = self + .apply_auth_header(self.client.post(&url).json(&request), credential) + .send() + .await?; if !response.status().is_success() { + let status = response.status(); let error = response.text().await?; - anyhow::bail!("{} API error: {error}", self.name); + let sanitized = super::sanitize_api_error(&error); + + if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { + return self + .chat_via_responses(credential, system_prompt, message, model) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", + self.name + ) + }); + } + + anyhow::bail!("{} API error ({status}): {sanitized}", self.name); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| { + // If tool_calls are present, serialize the full message as JSON + // so parse_tool_calls can handle the OpenAI-style format + if c.message.tool_calls.is_some() + && c.message + .tool_calls + .as_ref() + .map_or(false, |t| !t.is_empty()) + { + serde_json::to_string(&c.message) + .unwrap_or_else(|_| c.message.content.unwrap_or_default()) + } else { + // No tool calls, return content as-is + c.message.content.unwrap_or_default() + } + }) .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", + self.name + ) + })?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + stream: Some(false), + }; + + let url = self.chat_completions_url(); + let response = self + .apply_auth_header(self.client.post(&url).json(&request), credential) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + + // Mirror chat_with_system: 404 may mean this provider uses the Responses API + if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { + // Extract system prompt and last user message for responses fallback + let system = messages.iter().find(|m| m.role == "system"); + let last_user = messages.iter().rfind(|m| m.role == "user"); + if let Some(user_msg) = last_user { + return self + .chat_via_responses( + credential, + system.map(|m| m.content.as_str()), + &user_msg.content, + model, + ) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error (chat completions unavailable; responses fallback failed: {responses_err})", + self.name + ) + }); + } + } + + return Err(super::api_error(&self.name, response).await); + } + + let chat_response: ApiChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| { + // If tool_calls are present, serialize the full message as JSON + // so parse_tool_calls can handle the OpenAI-style format + if c.message.tool_calls.is_some() + && c.message + .tool_calls + .as_ref() + .map_or(false, |t| !t.is_empty()) + { + serde_json::to_string(&c.message) + .unwrap_or_else(|_| c.message.content.unwrap_or_default()) + } else { + // No tool calls, return content as-is + c.message.content.unwrap_or_default() + } + }) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let text = self + .chat_with_history(request.messages, model, temperature) + .await?; + + // Backward compatible path: chat_with_history may serialize tool_calls JSON into content. + if let Ok(message) = serde_json::from_str::(&text) { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .filter_map(|tc| { + let function = tc.function?; + let name = function.name?; + let arguments = function.arguments.unwrap_or_else(|| "{}".to_string()); + Some(ProviderToolCall { + id: uuid::Uuid::new_v4().to_string(), + name, + arguments, + }) + }) + .collect::>(); + + return Ok(ProviderChatResponse { + text: message.content, + tool_calls, + }); + } + + Ok(ProviderChatResponse { + text: Some(text), + tool_calls: vec![], + }) + } + + fn supports_native_tools(&self) -> bool { + true + } + + fn supports_streaming(&self) -> bool { + true + } + + fn stream_chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + options: StreamOptions, + ) -> stream::BoxStream<'static, StreamResult> { + let credential = match self.credential.as_ref() { + Some(value) => value.clone(), + None => { + let provider_name = self.name.clone(); + return stream::once(async move { + Err(StreamError::Provider(format!( + "{} API key not set", + provider_name + ))) + }) + .boxed(); + } + }; + + let mut messages = Vec::new(); + if let Some(sys) = system_prompt { + messages.push(Message { + role: "system".to_string(), + content: sys.to_string(), + }); + } + messages.push(Message { + role: "user".to_string(), + content: message.to_string(), + }); + + let request = ChatRequest { + model: model.to_string(), + messages, + temperature, + stream: Some(options.enabled), + }; + + let url = self.chat_completions_url(); + let client = self.client.clone(); + let auth_header = self.auth_header.clone(); + + // Use a channel to bridge the async HTTP response to the stream + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + + tokio::spawn(async move { + // Build request with auth + let mut req_builder = client.post(&url).json(&request); + + // Apply auth header + req_builder = match &auth_header { + AuthStyle::Bearer => { + req_builder.header("Authorization", format!("Bearer {}", credential)) + } + AuthStyle::XApiKey => req_builder.header("x-api-key", &credential), + AuthStyle::Custom(header) => req_builder.header(header, &credential), + }; + + // Set accept header for streaming + req_builder = req_builder.header("Accept", "text/event-stream"); + + // Send request + let response = match req_builder.send().await { + Ok(r) => r, + Err(e) => { + let _ = tx.send(Err(StreamError::Http(e))).await; + return; + } + }; + + // Check status + if !response.status().is_success() { + let status = response.status(); + let error = match response.text().await { + Ok(e) => e, + Err(_) => format!("HTTP error: {}", status), + }; + let _ = tx + .send(Err(StreamError::Provider(format!("{}: {}", status, error)))) + .await; + return; + } + + // Convert to chunk stream and forward to channel + let mut chunk_stream = sse_bytes_to_chunks(response, options.count_tokens); + while let Some(chunk) = chunk_stream.next().await { + if tx.send(chunk).await.is_err() { + break; // Receiver dropped + } + } + }); + + // Convert channel receiver to stream + stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|chunk| (chunk, rx)) + }) + .boxed() + } } #[cfg(test)] @@ -153,16 +787,20 @@ mod tests { #[test] fn creates_with_key() { - let p = make_provider("venice", "https://api.venice.ai", Some("vn-key")); + let p = make_provider( + "venice", + "https://api.venice.ai", + Some("venice-test-credential"), + ); assert_eq!(p.name, "venice"); assert_eq!(p.base_url, "https://api.venice.ai"); - assert_eq!(p.api_key.as_deref(), Some("vn-key")); + assert_eq!(p.credential.as_deref(), Some("venice-test-credential")); } #[test] fn creates_without_key() { let p = make_provider("test", "https://example.com", None); - assert!(p.api_key.is_none()); + assert!(p.credential.is_none()); } #[test] @@ -198,7 +836,8 @@ mod tests { content: "hello".to_string(), }, ], - temperature: 0.7, + temperature: 0.4, + stream: Some(false), }; let json = serde_json::to_string(&req).unwrap(); assert!(json.contains("llama-3.3-70b")); @@ -209,14 +848,17 @@ mod tests { #[test] fn response_deserializes() { let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.choices[0].message.content, "Hello from Venice!"); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + resp.choices[0].message.content, + Some("Hello from Venice!".to_string()) + ); } #[test] fn response_empty_choices() { let json = r#"{"choices":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.choices.is_empty()); } @@ -248,10 +890,11 @@ mod tests { make_provider("Venice", "https://api.venice.ai", None), make_provider("Moonshot", "https://api.moonshot.cn", None), make_provider("GLM", "https://open.bigmodel.cn", None), - make_provider("MiniMax", "https://api.minimax.chat", None), + make_provider("MiniMax", "https://api.minimaxi.com/v1", None), make_provider("Groq", "https://api.groq.com/openai", None), make_provider("Mistral", "https://api.mistral.ai", None), make_provider("xAI", "https://api.x.ai", None), + make_provider("Astrai", "https://as-trai.com/v1", None), ]; for p in providers { @@ -264,4 +907,226 @@ mod tests { ); } } + + #[test] + fn responses_extracts_top_level_output_text() { + let json = r#"{"output_text":"Hello from top-level","output":[]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Hello from top-level") + ); + } + + #[test] + fn responses_extracts_nested_output_text() { + let json = + r#"{"output":[{"content":[{"type":"output_text","text":"Hello from nested"}]}]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Hello from nested") + ); + } + + #[test] + fn responses_extracts_any_text_as_fallback() { + let json = r#"{"output":[{"content":[{"type":"message","text":"Fallback text"}]}]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Fallback text") + ); + } + + // ══════════════════════════════════════════════════════════ + // Custom endpoint path tests (Issue #114) + // ══════════════════════════════════════════════════════════ + + #[test] + fn chat_completions_url_standard_openai() { + // Standard OpenAI-compatible providers get /chat/completions appended + let p = make_provider("openai", "https://api.openai.com/v1", None); + assert_eq!( + p.chat_completions_url(), + "https://api.openai.com/v1/chat/completions" + ); + } + + #[test] + fn chat_completions_url_trailing_slash() { + // Trailing slash is stripped, then /chat/completions appended + let p = make_provider("test", "https://api.example.com/v1/", None); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/v1/chat/completions" + ); + } + + #[test] + fn chat_completions_url_volcengine_ark() { + // VolcEngine ARK uses custom path - should use as-is + let p = make_provider( + "volcengine", + "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions" + ); + } + + #[test] + fn chat_completions_url_custom_full_endpoint() { + // Custom provider with full endpoint path + let p = make_provider( + "custom", + "https://my-api.example.com/v2/llm/chat/completions", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://my-api.example.com/v2/llm/chat/completions" + ); + } + + #[test] + fn chat_completions_url_requires_exact_suffix_match() { + let p = make_provider( + "custom", + "https://my-api.example.com/v2/llm/chat/completions-proxy", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://my-api.example.com/v2/llm/chat/completions-proxy/chat/completions" + ); + } + + #[test] + fn responses_url_standard() { + // Standard providers get /v1/responses appended + let p = make_provider("test", "https://api.example.com", None); + assert_eq!(p.responses_url(), "https://api.example.com/v1/responses"); + } + + #[test] + fn responses_url_custom_full_endpoint() { + // Custom provider with full responses endpoint + let p = make_provider( + "custom", + "https://my-api.example.com/api/v2/responses", + None, + ); + assert_eq!( + p.responses_url(), + "https://my-api.example.com/api/v2/responses" + ); + } + + #[test] + fn responses_url_requires_exact_suffix_match() { + let p = make_provider( + "custom", + "https://my-api.example.com/api/v2/responses-proxy", + None, + ); + assert_eq!( + p.responses_url(), + "https://my-api.example.com/api/v2/responses-proxy/responses" + ); + } + + #[test] + fn responses_url_derives_from_chat_endpoint() { + let p = make_provider( + "custom", + "https://my-api.example.com/api/v2/chat/completions", + None, + ); + assert_eq!( + p.responses_url(), + "https://my-api.example.com/api/v2/responses" + ); + } + + #[test] + fn responses_url_base_with_v1_no_duplicate() { + let p = make_provider("test", "https://api.example.com/v1", None); + assert_eq!(p.responses_url(), "https://api.example.com/v1/responses"); + } + + #[test] + fn responses_url_non_v1_api_path_uses_raw_suffix() { + let p = make_provider("test", "https://api.example.com/api/coding/v3", None); + assert_eq!( + p.responses_url(), + "https://api.example.com/api/coding/v3/responses" + ); + } + + #[test] + fn chat_completions_url_without_v1() { + // Provider configured without /v1 in base URL + let p = make_provider("test", "https://api.example.com", None); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/chat/completions" + ); + } + + #[test] + fn chat_completions_url_base_with_v1() { + // Provider configured with /v1 in base URL + let p = make_provider("test", "https://api.example.com/v1", None); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/v1/chat/completions" + ); + } + + // ══════════════════════════════════════════════════════════ + // Provider-specific endpoint tests (Issue #167) + // ══════════════════════════════════════════════════════════ + + #[test] + fn chat_completions_url_zai() { + // Z.AI uses /api/paas/v4 base path + let p = make_provider("zai", "https://api.z.ai/api/paas/v4", None); + assert_eq!( + p.chat_completions_url(), + "https://api.z.ai/api/paas/v4/chat/completions" + ); + } + + #[test] + fn chat_completions_url_minimax() { + // MiniMax OpenAI-compatible endpoint requires /v1 base path. + let p = make_provider("minimax", "https://api.minimaxi.com/v1", None); + assert_eq!( + p.chat_completions_url(), + "https://api.minimaxi.com/v1/chat/completions" + ); + } + + #[test] + fn chat_completions_url_glm() { + // GLM (BigModel) uses /api/paas/v4 base path + let p = make_provider("glm", "https://open.bigmodel.cn/api/paas/v4", None); + assert_eq!( + p.chat_completions_url(), + "https://open.bigmodel.cn/api/paas/v4/chat/completions" + ); + } + + #[test] + fn chat_completions_url_opencode() { + // OpenCode Zen uses /zen/v1 base path + let p = make_provider("opencode", "https://opencode.ai/zen/v1", None); + assert_eq!( + p.chat_completions_url(), + "https://opencode.ai/zen/v1/chat/completions" + ); + } } diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs new file mode 100644 index 0000000..ab8eb3b --- /dev/null +++ b/src/providers/copilot.rs @@ -0,0 +1,705 @@ +//! GitHub Copilot provider with OAuth device-flow authentication. +//! +//! Authenticates via GitHub's device code flow (same as VS Code Copilot), +//! then exchanges the OAuth token for short-lived Copilot API keys. +//! Tokens are cached to disk and auto-refreshed. +//! +//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and +//! editor headers. This is the same approach used by LiteLLM, Codex CLI, +//! and other third-party Copilot integrations. The Copilot token endpoint is +//! private; there is no public OAuth scope or app registration for it. +//! GitHub could change or revoke this at any time, which would break all +//! third-party integrations simultaneously. + +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tracing::warn; + +/// GitHub OAuth client ID for Copilot (VS Code extension). +const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; +const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +const DEFAULT_API: &str = "https://api.githubcopilot.com"; + +// ── Token types ────────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default = "default_interval")] + interval: u64, + #[serde(default = "default_expires_in")] + expires_in: u64, +} + +fn default_interval() -> u64 { + 5 +} + +fn default_expires_in() -> u64 { + 900 +} + +#[derive(Debug, Deserialize)] +struct AccessTokenResponse { + access_token: Option, + error: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiKeyInfo { + token: String, + expires_at: i64, + #[serde(default)] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiEndpoints { + api: Option, +} + +struct CachedApiKey { + token: String, + api_endpoint: String, + expires_at: i64, +} + +// ── Chat completions types ─────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct ApiChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct ApiMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct ApiChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct ResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +// ── Provider ───────────────────────────────────────────────────── + +/// GitHub Copilot provider with automatic OAuth and token refresh. +/// +/// On first use, prompts the user to visit github.com/login/device. +/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed +/// automatically. +pub struct CopilotProvider { + github_token: Option, + /// Mutex ensures only one caller refreshes tokens at a time, + /// preventing duplicate device flow prompts or redundant API calls. + refresh_lock: Arc>>, + http: Client, + token_dir: PathBuf, +} + +impl CopilotProvider { + pub fn new(github_token: Option<&str>) -> Self { + let token_dir = directories::ProjectDirs::from("", "", "zeroclaw") + .map(|dir| dir.config_dir().join("copilot")) + .unwrap_or_else(|| { + // Fall back to a user-specific temp directory to avoid + // shared-directory symlink attacks. + let user = std::env::var("USER") + .or_else(|_| std::env::var("USERNAME")) + .unwrap_or_else(|_| "unknown".to_string()); + std::env::temp_dir().join(format!("zeroclaw-copilot-{user}")) + }); + + if let Err(err) = std::fs::create_dir_all(&token_dir) { + warn!( + "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.", + token_dir + ); + } else { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + if let Err(err) = + std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700)) + { + warn!( + "Failed to set Copilot token directory permissions on {:?}: {err}", + token_dir + ); + } + } + } + + Self { + github_token: github_token + .filter(|token| !token.is_empty()) + .map(String::from), + refresh_lock: Arc::new(Mutex::new(None)), + http: Client::builder() + .timeout(Duration::from_secs(120)) + .connect_timeout(Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + token_dir, + } + } + + /// Required headers for Copilot API requests (editor identification). + const COPILOT_HEADERS: [(&str, &str); 4] = [ + ("Editor-Version", "vscode/1.85.1"), + ("Editor-Plugin-Version", "copilot/1.155.0"), + ("User-Agent", "GithubCopilot/1.155.0"), + ("Accept", "application/json"), + ]; + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|message| { + if message.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&message.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tool_call| NativeToolCall { + id: Some(tool_call.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tool_call.name, + arguments: tool_call.arguments, + }, + }) + .collect::>(); + + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if message.role == "tool" { + if let Ok(value) = serde_json::from_str::(&message.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + ApiMessage { + role: message.role.clone(), + content: Some(message.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + /// Send a chat completions request with required Copilot headers. + async fn send_chat_request( + &self, + messages: Vec, + tools: Option<&[ToolSpec]>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (token, endpoint) = self.get_api_key().await?; + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let native_tools = Self::convert_tools(tools); + let request = ApiChatRequest { + model: model.to_string(), + messages, + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let mut req = self + .http + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&request); + + for (header, value) in &Self::COPILOT_HEADERS { + req = req.header(*header, *value); + } + + let response = req.send().await?; + + if !response.status().is_success() { + return Err(super::api_error("GitHub Copilot", response).await); + } + + let api_response: ApiChatResponse = response.json().await?; + let choice = api_response + .choices + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?; + + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tool_call| ProviderToolCall { + id: tool_call + .id + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }) + .collect(); + + Ok(ProviderChatResponse { + text: choice.message.content, + tool_calls, + }) + } + + /// Get a valid Copilot API key, refreshing or re-authenticating as needed. + /// Uses a Mutex to ensure only one caller refreshes at a time. + async fn get_api_key(&self) -> anyhow::Result<(String, String)> { + let mut cached = self.refresh_lock.lock().await; + + if let Some(cached_key) = cached.as_ref() { + if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at { + return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); + } + } + + if let Some(info) = self.load_api_key_from_disk().await { + if chrono::Utc::now().timestamp() + 120 < info.expires_at { + let endpoint = info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + let token = info.token; + + *cached = Some(CachedApiKey { + token: token.clone(), + api_endpoint: endpoint.clone(), + expires_at: info.expires_at, + }); + return Ok((token, endpoint)); + } + } + + let access_token = self.get_github_access_token().await?; + let api_key_info = self.exchange_for_api_key(&access_token).await?; + self.save_api_key_to_disk(&api_key_info).await; + + let endpoint = api_key_info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + + *cached = Some(CachedApiKey { + token: api_key_info.token.clone(), + api_endpoint: endpoint.clone(), + expires_at: api_key_info.expires_at, + }); + + Ok((api_key_info.token, endpoint)) + } + + /// Get a GitHub access token from config, cache, or device flow. + async fn get_github_access_token(&self) -> anyhow::Result { + if let Some(token) = &self.github_token { + return Ok(token.clone()); + } + + let access_token_path = self.token_dir.join("access-token"); + if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await { + let token = cached.trim(); + if !token.is_empty() { + return Ok(token.to_string()); + } + } + + let token = self.device_code_login().await?; + write_file_secure(&access_token_path, &token).await; + Ok(token) + } + + /// Run GitHub OAuth device code flow. + async fn device_code_login(&self) -> anyhow::Result { + let response: DeviceCodeResponse = self + .http + .post(GITHUB_DEVICE_CODE_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "scope": "read:user" + })) + .send() + .await? + .error_for_status()? + .json() + .await?; + + let mut poll_interval = Duration::from_secs(response.interval.max(5)); + let expires_in = response.expires_in.max(1); + let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in); + + eprintln!( + "\nGitHub Copilot authentication is required.\n\ + Visit: {}\n\ + Code: {}\n\ + Waiting for authorization...\n", + response.verification_uri, response.user_code + ); + + while tokio::time::Instant::now() < expires_at { + tokio::time::sleep(poll_interval).await; + + let token_response: AccessTokenResponse = self + .http + .post(GITHUB_ACCESS_TOKEN_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "device_code": response.device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code" + })) + .send() + .await? + .json() + .await?; + + if let Some(token) = token_response.access_token { + eprintln!("Authentication succeeded.\n"); + return Ok(token); + } + + match token_response.error.as_deref() { + Some("slow_down") => { + poll_interval += Duration::from_secs(5); + } + Some("authorization_pending") | None => {} + Some("expired_token") => { + anyhow::bail!("GitHub device authorization expired") + } + Some(error) => anyhow::bail!("GitHub auth failed: {error}"), + } + } + + anyhow::bail!("Timed out waiting for GitHub authorization") + } + + /// Exchange a GitHub access token for a Copilot API key. + async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result { + let mut request = self.http.get(GITHUB_API_KEY_URL); + for (header, value) in &Self::COPILOT_HEADERS { + request = request.header(*header, *value); + } + request = request.header("Authorization", format!("token {access_token}")); + + let response = request.send().await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + let sanitized = super::sanitize_api_error(&body); + + if status.as_u16() == 401 || status.as_u16() == 403 { + let access_token_path = self.token_dir.join("access-token"); + tokio::fs::remove_file(&access_token_path).await.ok(); + } + + anyhow::bail!( + "Failed to get Copilot API key ({status}): {sanitized}. \ + Ensure your GitHub account has an active Copilot subscription." + ); + } + + let info: ApiKeyInfo = response.json().await?; + Ok(info) + } + + async fn load_api_key_from_disk(&self) -> Option { + let path = self.token_dir.join("api-key.json"); + let data = tokio::fs::read_to_string(&path).await.ok()?; + serde_json::from_str(&data).ok() + } + + async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) { + let path = self.token_dir.join("api-key.json"); + if let Ok(json) = serde_json::to_string_pretty(info) { + write_file_secure(&path, &json).await; + } + } +} + +/// Write a file with 0600 permissions (owner read/write only). +/// Uses `spawn_blocking` to avoid blocking the async runtime. +async fn write_file_secure(path: &Path, content: &str) { + let path = path.to_path_buf(); + let content = content.to_string(); + + let result = tokio::task::spawn_blocking(move || { + #[cfg(unix)] + { + use std::io::Write; + use std::os::unix::fs::{OpenOptionsExt, PermissionsExt}; + + let mut file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(&path)?; + file.write_all(content.as_bytes())?; + + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?; + Ok::<(), std::io::Error>(()) + } + #[cfg(not(unix))] + { + std::fs::write(&path, &content)?; + Ok::<(), std::io::Error>(()) + } + }) + .await; + + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => warn!("Failed to write secure file: {err}"), + Err(err) => warn!("Failed to spawn blocking write: {err}"), + } +} + +#[async_trait] +impl Provider for CopilotProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut messages = Vec::new(); + if let Some(system) = system_prompt { + messages.push(ApiMessage { + role: "system".to_string(), + content: Some(system.to_string()), + tool_call_id: None, + tool_calls: None, + }); + } + messages.push(ApiMessage { + role: "user".to_string(), + content: Some(message.to_string()), + tool_call_id: None, + tool_calls: None, + }); + + let response = self + .send_chat_request(messages, None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let response = self + .send_chat_request(Self::convert_messages(messages), None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.send_chat_request( + Self::convert_messages(request.messages), + request.tools, + model, + temperature, + ) + .await + } + + fn supports_native_tools(&self) -> bool { + true + } + + async fn warmup(&self) -> anyhow::Result<()> { + let _ = self.get_api_key().await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_without_token() { + let provider = CopilotProvider::new(None); + assert!(provider.github_token.is_none()); + } + + #[test] + fn new_with_token() { + let provider = CopilotProvider::new(Some("ghp_test")); + assert_eq!(provider.github_token.as_deref(), Some("ghp_test")); + } + + #[test] + fn empty_token_treated_as_none() { + let provider = CopilotProvider::new(Some("")); + assert!(provider.github_token.is_none()); + } + + #[tokio::test] + async fn cache_starts_empty() { + let provider = CopilotProvider::new(None); + let cached = provider.refresh_lock.lock().await; + assert!(cached.is_none()); + } + + #[test] + fn copilot_headers_include_required_fields() { + let headers = CopilotProvider::COPILOT_HEADERS; + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Version")); + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Plugin-Version")); + assert!(headers.iter().any(|(header, _)| *header == "User-Agent")); + } + + #[test] + fn default_interval_and_expiry() { + assert_eq!(default_interval(), 5); + assert_eq!(default_expires_in(), 900); + } + + #[test] + fn supports_native_tools() { + let provider = CopilotProvider::new(None); + assert!(provider.supports_native_tools()); + } +} diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs new file mode 100644 index 0000000..a988224 --- /dev/null +++ b/src/providers/gemini.rs @@ -0,0 +1,560 @@ +//! Google Gemini provider with support for: +//! - Direct API key (`GEMINI_API_KEY` env var or config) +//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) +//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) + +use crate::providers::traits::Provider; +use async_trait::async_trait; +use directories::UserDirs; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Gemini provider supporting multiple authentication methods. +pub struct GeminiProvider { + auth: Option, + client: Client, +} + +/// Resolved credential — the variant determines both the HTTP auth method +/// and the diagnostic label returned by `auth_source()`. +#[derive(Debug)] +enum GeminiAuth { + /// Explicit API key from config: sent as `?key=` query parameter. + ExplicitKey(String), + /// API key from `GEMINI_API_KEY` env var: sent as `?key=`. + EnvGeminiKey(String), + /// API key from `GOOGLE_API_KEY` env var: sent as `?key=`. + EnvGoogleKey(String), + /// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`. + OAuthToken(String), +} + +impl GeminiAuth { + /// Whether this credential is an API key (sent as `?key=` query param). + fn is_api_key(&self) -> bool { + matches!( + self, + GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_) + ) + } + + /// The raw credential string. + fn credential(&self) -> &str { + match self { + GeminiAuth::ExplicitKey(s) + | GeminiAuth::EnvGeminiKey(s) + | GeminiAuth::EnvGoogleKey(s) + | GeminiAuth::OAuthToken(s) => s, + } + } +} + +// ══════════════════════════════════════════════════════════════════════════════ +// API REQUEST/RESPONSE TYPES +// ══════════════════════════════════════════════════════════════════════════════ + +#[derive(Debug, Serialize)] +struct GenerateContentRequest { + contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system_instruction: Option, + #[serde(rename = "generationConfig")] + generation_config: GenerationConfig, +} + +#[derive(Debug, Serialize)] +struct Content { + #[serde(skip_serializing_if = "Option::is_none")] + role: Option, + parts: Vec, +} + +#[derive(Debug, Serialize)] +struct Part { + text: String, +} + +#[derive(Debug, Serialize)] +struct GenerationConfig { + temperature: f64, + #[serde(rename = "maxOutputTokens")] + max_output_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct GenerateContentResponse { + candidates: Option>, + error: Option, +} + +#[derive(Debug, Deserialize)] +struct Candidate { + content: CandidateContent, +} + +#[derive(Debug, Deserialize)] +struct CandidateContent { + parts: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponsePart { + text: Option, +} + +#[derive(Debug, Deserialize)] +struct ApiError { + message: String, +} + +// ══════════════════════════════════════════════════════════════════════════════ +// GEMINI CLI TOKEN STRUCTURES +// ══════════════════════════════════════════════════════════════════════════════ + +/// OAuth token stored by Gemini CLI in `~/.gemini/oauth_creds.json` +#[derive(Debug, Deserialize)] +struct GeminiCliOAuthCreds { + access_token: Option, + expiry: Option, +} + +impl GeminiProvider { + /// Create a new Gemini provider. + /// + /// Authentication priority: + /// 1. Explicit API key passed in + /// 2. `GEMINI_API_KEY` environment variable + /// 3. `GOOGLE_API_KEY` environment variable + /// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`) + pub fn new(api_key: Option<&str>) -> Self { + let resolved_auth = api_key + .and_then(Self::normalize_non_empty) + .map(GeminiAuth::ExplicitKey) + .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey)) + .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey)) + .or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken)); + + Self { + auth: resolved_auth, + client: Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + } + } + + fn normalize_non_empty(value: &str) -> Option { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + } + + fn load_non_empty_env(name: &str) -> Option { + std::env::var(name) + .ok() + .and_then(|value| Self::normalize_non_empty(&value)) + } + + /// Try to load OAuth access token from Gemini CLI's cached credentials. + /// Location: `~/.gemini/oauth_creds.json` + fn try_load_gemini_cli_token() -> Option { + let gemini_dir = Self::gemini_cli_dir()?; + let creds_path = gemini_dir.join("oauth_creds.json"); + + if !creds_path.exists() { + return None; + } + + let content = std::fs::read_to_string(&creds_path).ok()?; + let creds: GeminiCliOAuthCreds = serde_json::from_str(&content).ok()?; + + // Check if token is expired (basic check) + if let Some(ref expiry) = creds.expiry { + if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) { + if expiry_time < chrono::Utc::now() { + tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh"); + return None; + } + } + } + + creds + .access_token + .and_then(|token| Self::normalize_non_empty(&token)) + } + + /// Get the Gemini CLI config directory (~/.gemini) + fn gemini_cli_dir() -> Option { + UserDirs::new().map(|u| u.home_dir().join(".gemini")) + } + + /// Check if Gemini CLI is configured and has valid credentials + pub fn has_cli_credentials() -> bool { + Self::try_load_gemini_cli_token().is_some() + } + + /// Check if any Gemini authentication is available + pub fn has_any_auth() -> bool { + Self::load_non_empty_env("GEMINI_API_KEY").is_some() + || Self::load_non_empty_env("GOOGLE_API_KEY").is_some() + || Self::has_cli_credentials() + } + + /// Get authentication source description for diagnostics. + /// Uses the stored enum variant — no env var re-reading at call time. + pub fn auth_source(&self) -> &'static str { + match self.auth.as_ref() { + Some(GeminiAuth::ExplicitKey(_)) => "config", + Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var", + Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var", + Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth", + None => "none", + } + } + + fn format_model_name(model: &str) -> String { + if model.starts_with("models/") { + model.to_string() + } else { + format!("models/{model}") + } + } + + fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String { + let model_name = Self::format_model_name(model); + let base_url = format!( + "https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent" + ); + + if auth.is_api_key() { + format!("{base_url}?key={}", auth.credential()) + } else { + base_url + } + } + + fn build_generate_content_request( + &self, + auth: &GeminiAuth, + url: &str, + request: &GenerateContentRequest, + ) -> reqwest::RequestBuilder { + let req = self.client.post(url).json(request); + match auth { + GeminiAuth::OAuthToken(token) => req.bearer_auth(token), + _ => req, + } + } +} + +#[async_trait] +impl Provider for GeminiProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let auth = self.auth.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Gemini API key not found. Options:\n\ + 1. Set GEMINI_API_KEY env var\n\ + 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\ + 3. Get an API key from https://aistudio.google.com/app/apikey\n\ + 4. Run `zeroclaw onboard` to configure" + ) + })?; + + // Build request + let system_instruction = system_prompt.map(|sys| Content { + role: None, + parts: vec![Part { + text: sys.to_string(), + }], + }); + + let request = GenerateContentRequest { + contents: vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: message.to_string(), + }], + }], + system_instruction, + generation_config: GenerationConfig { + temperature, + max_output_tokens: 8192, + }, + }; + + let url = Self::build_generate_content_url(model, auth); + + let response = self + .build_generate_content_request(auth, &url, &request) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + anyhow::bail!("Gemini API error ({status}): {error_text}"); + } + + let result: GenerateContentResponse = response.json().await?; + + // Check for API error in response body + if let Some(err) = result.error { + anyhow::bail!("Gemini API error: {}", err.message); + } + + // Extract text from response + result + .candidates + .and_then(|c| c.into_iter().next()) + .and_then(|c| c.content.parts.into_iter().next()) + .and_then(|p| p.text) + .ok_or_else(|| anyhow::anyhow!("No response from Gemini")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::header::AUTHORIZATION; + + #[test] + fn normalize_non_empty_trims_and_filters() { + assert_eq!( + GeminiProvider::normalize_non_empty(" value "), + Some("value".into()) + ); + assert_eq!(GeminiProvider::normalize_non_empty(""), None); + assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None); + } + + #[test] + fn provider_creates_without_key() { + let provider = GeminiProvider::new(None); + // May pick up env vars; just verify it doesn't panic + let _ = provider.auth_source(); + } + + #[test] + fn provider_creates_with_key() { + let provider = GeminiProvider::new(Some("test-api-key")); + assert!(matches!( + provider.auth, + Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key" + )); + } + + #[test] + fn provider_rejects_empty_key() { + let provider = GeminiProvider::new(Some("")); + assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_)))); + } + + #[test] + fn gemini_cli_dir_returns_path() { + let dir = GeminiProvider::gemini_cli_dir(); + // Should return Some on systems with home dir + if UserDirs::new().is_some() { + assert!(dir.is_some()); + assert!(dir.unwrap().ends_with(".gemini")); + } + } + + #[test] + fn auth_source_explicit_key() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::ExplicitKey("key".into())), + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "config"); + } + + #[test] + fn auth_source_none_without_credentials() { + let provider = GeminiProvider { + auth: None, + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "none"); + } + + #[test] + fn auth_source_oauth() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())), + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "Gemini CLI OAuth"); + } + + #[test] + fn model_name_formatting() { + assert_eq!( + GeminiProvider::format_model_name("gemini-2.0-flash"), + "models/gemini-2.0-flash" + ); + assert_eq!( + GeminiProvider::format_model_name("models/gemini-1.5-pro"), + "models/gemini-1.5-pro" + ); + } + + #[test] + fn api_key_url_includes_key_query_param() { + let auth = GeminiAuth::ExplicitKey("api-key-123".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + assert!(url.contains(":generateContent?key=api-key-123")); + } + + #[test] + fn oauth_url_omits_key_query_param() { + let auth = GeminiAuth::OAuthToken("ya29.test-token".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + assert!(url.ends_with(":generateContent")); + assert!(!url.contains("?key=")); + } + + #[test] + fn oauth_request_uses_bearer_auth_header() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())), + client: Client::new(), + }; + let auth = GeminiAuth::OAuthToken("ya29.mock-token".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + let body = GenerateContentRequest { + contents: vec![Content { + role: Some("user".into()), + parts: vec![Part { + text: "hello".into(), + }], + }], + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let request = provider + .build_generate_content_request(&auth, &url, &body) + .build() + .unwrap(); + + assert_eq!( + request + .headers() + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()), + Some("Bearer ya29.mock-token") + ); + } + + #[test] + fn api_key_request_does_not_set_bearer_header() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())), + client: Client::new(), + }; + let auth = GeminiAuth::ExplicitKey("api-key-123".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + let body = GenerateContentRequest { + contents: vec![Content { + role: Some("user".into()), + parts: vec![Part { + text: "hello".into(), + }], + }], + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let request = provider + .build_generate_content_request(&auth, &url, &body) + .build() + .unwrap(); + + assert!(request.headers().get(AUTHORIZATION).is_none()); + } + + #[test] + fn request_serialization() { + let request = GenerateContentRequest { + contents: vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: "Hello".to_string(), + }], + }], + system_instruction: Some(Content { + role: None, + parts: vec![Part { + text: "You are helpful".to_string(), + }], + }), + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"text\":\"Hello\"")); + assert!(json.contains("\"temperature\":0.7")); + assert!(json.contains("\"maxOutputTokens\":8192")); + } + + #[test] + fn response_deserialization() { + let json = r#"{ + "candidates": [{ + "content": { + "parts": [{"text": "Hello there!"}] + } + }] + }"#; + + let response: GenerateContentResponse = serde_json::from_str(json).unwrap(); + assert!(response.candidates.is_some()); + let text = response + .candidates + .unwrap() + .into_iter() + .next() + .unwrap() + .content + .parts + .into_iter() + .next() + .unwrap() + .text; + assert_eq!(text, Some("Hello there!".to_string())); + } + + #[test] + fn error_response_deserialization() { + let json = r#"{ + "error": { + "message": "Invalid API key" + } + }"#; + + let response: GenerateContentResponse = serde_json::from_str(json).unwrap(); + assert!(response.error.is_some()); + assert_eq!(response.error.unwrap().message, "Invalid API key"); + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 1ec33ac..15d8316 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,112 +1,533 @@ pub mod anthropic; pub mod compatible; -pub mod glm; +pub mod copilot; +pub mod gemini; pub mod ollama; pub mod openai; pub mod openrouter; pub mod reliable; +pub mod router; pub mod traits; -pub use traits::Provider; +#[allow(unused_imports)] +pub use traits::{ + ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall, + ToolResultMessage, +}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; -/// Factory: create the right provider from config -#[allow(clippy::too_many_lines)] +const MAX_API_ERROR_CHARS: usize = 200; +const MINIMAX_INTL_BASE_URL: &str = "https://api.minimax.io/v1"; +const MINIMAX_CN_BASE_URL: &str = "https://api.minimaxi.com/v1"; +const GLM_GLOBAL_BASE_URL: &str = "https://api.z.ai/api/paas/v4"; +const GLM_CN_BASE_URL: &str = "https://open.bigmodel.cn/api/paas/v4"; +const MOONSHOT_INTL_BASE_URL: &str = "https://api.moonshot.ai/v1"; +const MOONSHOT_CN_BASE_URL: &str = "https://api.moonshot.cn/v1"; +const QWEN_CN_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; +const QWEN_INTL_BASE_URL: &str = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"; +const QWEN_US_BASE_URL: &str = "https://dashscope-us.aliyuncs.com/compatible-mode/v1"; +const ZAI_GLOBAL_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4"; +const ZAI_CN_BASE_URL: &str = "https://open.bigmodel.cn/api/coding/paas/v4"; + +pub(crate) fn is_minimax_intl_alias(name: &str) -> bool { + matches!( + name, + "minimax" | "minimax-intl" | "minimax-io" | "minimax-global" + ) +} + +pub(crate) fn is_minimax_cn_alias(name: &str) -> bool { + matches!(name, "minimax-cn" | "minimaxi") +} + +pub(crate) fn is_minimax_alias(name: &str) -> bool { + is_minimax_intl_alias(name) || is_minimax_cn_alias(name) +} + +pub(crate) fn is_glm_global_alias(name: &str) -> bool { + matches!(name, "glm" | "zhipu" | "glm-global" | "zhipu-global") +} + +pub(crate) fn is_glm_cn_alias(name: &str) -> bool { + matches!(name, "glm-cn" | "zhipu-cn" | "bigmodel") +} + +pub(crate) fn is_glm_alias(name: &str) -> bool { + is_glm_global_alias(name) || is_glm_cn_alias(name) +} + +pub(crate) fn is_moonshot_intl_alias(name: &str) -> bool { + matches!( + name, + "moonshot-intl" | "moonshot-global" | "kimi-intl" | "kimi-global" + ) +} + +pub(crate) fn is_moonshot_cn_alias(name: &str) -> bool { + matches!(name, "moonshot" | "kimi" | "moonshot-cn" | "kimi-cn") +} + +pub(crate) fn is_moonshot_alias(name: &str) -> bool { + is_moonshot_intl_alias(name) || is_moonshot_cn_alias(name) +} + +pub(crate) fn is_qwen_cn_alias(name: &str) -> bool { + matches!(name, "qwen" | "dashscope" | "qwen-cn" | "dashscope-cn") +} + +pub(crate) fn is_qwen_intl_alias(name: &str) -> bool { + matches!( + name, + "qwen-intl" | "dashscope-intl" | "qwen-international" | "dashscope-international" + ) +} + +pub(crate) fn is_qwen_us_alias(name: &str) -> bool { + matches!(name, "qwen-us" | "dashscope-us") +} + +pub(crate) fn is_qwen_alias(name: &str) -> bool { + is_qwen_cn_alias(name) || is_qwen_intl_alias(name) || is_qwen_us_alias(name) +} + +pub(crate) fn is_zai_global_alias(name: &str) -> bool { + matches!(name, "zai" | "z.ai" | "zai-global" | "z.ai-global") +} + +pub(crate) fn is_zai_cn_alias(name: &str) -> bool { + matches!(name, "zai-cn" | "z.ai-cn") +} + +pub(crate) fn is_zai_alias(name: &str) -> bool { + is_zai_global_alias(name) || is_zai_cn_alias(name) +} + +pub(crate) fn is_qianfan_alias(name: &str) -> bool { + matches!(name, "qianfan" | "baidu") +} + +pub(crate) fn canonical_china_provider_name(name: &str) -> Option<&'static str> { + if is_qwen_alias(name) { + Some("qwen") + } else if is_glm_alias(name) { + Some("glm") + } else if is_moonshot_alias(name) { + Some("moonshot") + } else if is_minimax_alias(name) { + Some("minimax") + } else if is_zai_alias(name) { + Some("zai") + } else if is_qianfan_alias(name) { + Some("qianfan") + } else { + None + } +} + +fn minimax_base_url(name: &str) -> Option<&'static str> { + if is_minimax_cn_alias(name) { + Some(MINIMAX_CN_BASE_URL) + } else if is_minimax_intl_alias(name) { + Some(MINIMAX_INTL_BASE_URL) + } else { + None + } +} + +fn glm_base_url(name: &str) -> Option<&'static str> { + if is_glm_cn_alias(name) { + Some(GLM_CN_BASE_URL) + } else if is_glm_global_alias(name) { + Some(GLM_GLOBAL_BASE_URL) + } else { + None + } +} + +fn moonshot_base_url(name: &str) -> Option<&'static str> { + if is_moonshot_intl_alias(name) { + Some(MOONSHOT_INTL_BASE_URL) + } else if is_moonshot_cn_alias(name) { + Some(MOONSHOT_CN_BASE_URL) + } else { + None + } +} + +fn qwen_base_url(name: &str) -> Option<&'static str> { + if is_qwen_cn_alias(name) { + Some(QWEN_CN_BASE_URL) + } else if is_qwen_intl_alias(name) { + Some(QWEN_INTL_BASE_URL) + } else if is_qwen_us_alias(name) { + Some(QWEN_US_BASE_URL) + } else { + None + } +} + +fn zai_base_url(name: &str) -> Option<&'static str> { + if is_zai_cn_alias(name) { + Some(ZAI_CN_BASE_URL) + } else if is_zai_global_alias(name) { + Some(ZAI_GLOBAL_BASE_URL) + } else { + None + } +} + +fn is_secret_char(c: char) -> bool { + c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | ':') +} + +fn token_end(input: &str, from: usize) -> usize { + let mut end = from; + for (i, c) in input[from..].char_indices() { + if is_secret_char(c) { + end = from + i + c.len_utf8(); + } else { + break; + } + } + end +} + +/// Scrub known secret-like token prefixes from provider error strings. +/// +/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`, +/// `ghu_`, and `github_pat_`. +pub fn scrub_secret_patterns(input: &str) -> String { + const PREFIXES: [&str; 7] = [ + "sk-", + "xoxb-", + "xoxp-", + "ghp_", + "gho_", + "ghu_", + "github_pat_", + ]; + + let mut scrubbed = input.to_string(); + + for prefix in PREFIXES { + let mut search_from = 0; + loop { + let Some(rel) = scrubbed[search_from..].find(prefix) else { + break; + }; + + let start = search_from + rel; + let content_start = start + prefix.len(); + let end = token_end(&scrubbed, content_start); + + // Bare prefixes like "sk-" should not stop future scans. + if end == content_start { + search_from = content_start; + continue; + } + + scrubbed.replace_range(start..end, "[REDACTED]"); + search_from = start + "[REDACTED]".len(); + } + } + + scrubbed +} + +/// Sanitize API error text by scrubbing secrets and truncating length. +pub fn sanitize_api_error(input: &str) -> String { + let scrubbed = scrub_secret_patterns(input); + + if scrubbed.chars().count() <= MAX_API_ERROR_CHARS { + return scrubbed; + } + + let mut end = MAX_API_ERROR_CHARS; + while end > 0 && !scrubbed.is_char_boundary(end) { + end -= 1; + } + + format!("{}...", &scrubbed[..end]) +} + +/// Build a sanitized provider error from a failed HTTP response. +pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error { + let status = response.status(); + let body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); + let sanitized = sanitize_api_error(&body); + anyhow::anyhow!("{provider} API error ({status}): {sanitized}") +} + +/// Resolve API key for a provider from config and environment variables. +/// +/// Resolution order: +/// 1. Explicitly provided `api_key` parameter (trimmed, filtered if empty) +/// 2. Provider-specific environment variable (e.g., `ANTHROPIC_OAUTH_TOKEN`, `OPENROUTER_API_KEY`) +/// 3. Generic fallback variables (`ZEROCLAW_API_KEY`, `API_KEY`) +/// +/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens) +/// followed by `ANTHROPIC_API_KEY` (for regular API keys). +fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option { + if let Some(raw_override) = credential_override { + let trimmed_override = raw_override.trim(); + if !trimmed_override.is_empty() { + return Some(trimmed_override.to_owned()); + } + } + + let provider_env_candidates: Vec<&str> = match name { + "anthropic" => vec!["ANTHROPIC_OAUTH_TOKEN", "ANTHROPIC_API_KEY"], + "openrouter" => vec!["OPENROUTER_API_KEY"], + "openai" => vec!["OPENAI_API_KEY"], + "ollama" => vec!["OLLAMA_API_KEY"], + "venice" => vec!["VENICE_API_KEY"], + "groq" => vec!["GROQ_API_KEY"], + "mistral" => vec!["MISTRAL_API_KEY"], + "deepseek" => vec!["DEEPSEEK_API_KEY"], + "xai" | "grok" => vec!["XAI_API_KEY"], + "together" | "together-ai" => vec!["TOGETHER_API_KEY"], + "fireworks" | "fireworks-ai" => vec!["FIREWORKS_API_KEY"], + "perplexity" => vec!["PERPLEXITY_API_KEY"], + "cohere" => vec!["COHERE_API_KEY"], + name if is_moonshot_alias(name) => vec!["MOONSHOT_API_KEY"], + name if is_glm_alias(name) => vec!["GLM_API_KEY"], + name if is_minimax_alias(name) => vec!["MINIMAX_API_KEY"], + name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"], + name if is_qwen_alias(name) => vec!["DASHSCOPE_API_KEY"], + name if is_zai_alias(name) => vec!["ZAI_API_KEY"], + "nvidia" | "nvidia-nim" | "build.nvidia.com" => vec!["NVIDIA_API_KEY"], + "synthetic" => vec!["SYNTHETIC_API_KEY"], + "opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"], + "vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"], + "cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"], + "astrai" => vec!["ASTRAI_API_KEY"], + _ => vec![], + }; + + for env_var in provider_env_candidates { + if let Ok(value) = std::env::var(env_var) { + let value = value.trim(); + if !value.is_empty() { + return Some(value.to_string()); + } + } + } + + for env_var in ["ZEROCLAW_API_KEY", "API_KEY"] { + if let Ok(value) = std::env::var(env_var) { + let value = value.trim(); + if !value.is_empty() { + return Some(value.to_string()); + } + } + } + + None +} + +fn parse_custom_provider_url( + raw_url: &str, + provider_label: &str, + format_hint: &str, +) -> anyhow::Result { + let base_url = raw_url.trim(); + + if base_url.is_empty() { + anyhow::bail!("{provider_label} requires a URL. Format: {format_hint}"); + } + + let parsed = reqwest::Url::parse(base_url).map_err(|_| { + anyhow::anyhow!("{provider_label} requires a valid URL. Format: {format_hint}") + })?; + + match parsed.scheme() { + "http" | "https" => Ok(base_url.to_string()), + _ => anyhow::bail!( + "{provider_label} requires an http:// or https:// URL. Format: {format_hint}" + ), + } +} + +/// Factory: create the right provider from config (without custom URL) pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { + create_provider_with_url(name, api_key, None) +} + +/// Factory: create the right provider from config with optional custom base URL +#[allow(clippy::too_many_lines)] +pub fn create_provider_with_url( + name: &str, + api_key: Option<&str>, + api_url: Option<&str>, +) -> anyhow::Result> { + let resolved_credential = resolve_provider_credential(name, api_key); + #[allow(clippy::option_as_ref_deref)] + let key = resolved_credential.as_ref().map(String::as_str); match name { // ── Primary providers (custom implementations) ─────── - "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))), - "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))), - "openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))), - "ollama" => Ok(Box::new(ollama::OllamaProvider::new( - api_key.filter(|k| !k.is_empty()), - ))), + "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))), + "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), + "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), + // Ollama uses api_url for custom base URL (e.g. remote Ollama instance) + "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url, key))), + "gemini" | "google" | "google-gemini" => { + Ok(Box::new(gemini::GeminiProvider::new(key))) + } // ── OpenAI-compatible providers ────────────────────── "venice" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer, + "Venice", "https://api.venice.ai", key, AuthStyle::Bearer, ))), "vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer, + "Vercel AI Gateway", "https://api.vercel.ai", key, AuthStyle::Bearer, ))), "cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "Cloudflare AI Gateway", "https://gateway.ai.cloudflare.com/v1", - api_key, + key, AuthStyle::Bearer, ))), - "moonshot" | "kimi" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer, + name if moonshot_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new( + "Moonshot", + moonshot_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, ))), "synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer, + "Synthetic", "https://api.synthetic.com", key, AuthStyle::Bearer, ))), "opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new( - "OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer, + "OpenCode Zen", "https://opencode.ai/zen/v1", key, AuthStyle::Bearer, ))), - "zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Z.AI", "https://api.z.ai", api_key, AuthStyle::Bearer, + name if zai_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new( + "Z.AI", + zai_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, ))), - "glm" | "zhipu" => Ok(Box::new(glm::GlmProvider::new(api_key))), - "minimax" => Ok(Box::new(OpenAiCompatibleProvider::new( - "MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer, + name if glm_base_url(name).is_some() => { + Ok(Box::new(OpenAiCompatibleProvider::new_no_responses_fallback( + "GLM", + glm_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, + ))) + } + name if minimax_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new( + "MiniMax", + minimax_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, ))), "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( "Amazon Bedrock", "https://bedrock-runtime.us-east-1.amazonaws.com", - api_key, + key, AuthStyle::Bearer, ))), - "qianfan" | "baidu" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer, + name if is_qianfan_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new( + "Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer, + ))), + name if qwen_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new( + "Qwen", + qwen_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, ))), // ── Extended ecosystem (community favorites) ───────── "groq" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer, + "Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer, ))), "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer, + "Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer, ))), "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( - "xAI", "https://api.x.ai", api_key, AuthStyle::Bearer, + "xAI", "https://api.x.ai", key, AuthStyle::Bearer, ))), "deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new( - "DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer, + "DeepSeek", "https://api.deepseek.com", key, AuthStyle::Bearer, ))), "together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer, + "Together AI", "https://api.together.xyz", key, AuthStyle::Bearer, ))), "fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer, + "Fireworks AI", "https://api.fireworks.ai/inference/v1", key, AuthStyle::Bearer, ))), "perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer, + "Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer, ))), "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer, + "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, + ))), + "copilot" | "github-copilot" => { + Ok(Box::new(copilot::CopilotProvider::new(api_key))) + }, + "lmstudio" | "lm-studio" => { + let lm_studio_key = api_key + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or("lm-studio"); + Ok(Box::new(OpenAiCompatibleProvider::new( + "LM Studio", + "http://localhost:1234/v1", + Some(lm_studio_key), + AuthStyle::Bearer, + ))) + } + "nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new( + OpenAiCompatibleProvider::new( + "NVIDIA NIM", + "https://integrate.api.nvidia.com/v1", + key, + AuthStyle::Bearer, + ), + )), + + // ── AI inference routers ───────────────────────────── + "astrai" => Ok(Box::new(OpenAiCompatibleProvider::new( + "Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer, ))), // ── Bring Your Own Provider (custom URL) ─────────── // Format: "custom:https://your-api.com" or "custom:http://localhost:1234" name if name.starts_with("custom:") => { - let base_url = name.strip_prefix("custom:").unwrap_or(""); - if base_url.is_empty() { - anyhow::bail!("Custom provider requires a URL. Format: custom:https://your-api.com"); - } + let base_url = parse_custom_provider_url( + name.strip_prefix("custom:").unwrap_or(""), + "Custom provider", + "custom:https://your-api.com", + )?; Ok(Box::new(OpenAiCompatibleProvider::new( "Custom", - base_url, - api_key, + &base_url, + key, AuthStyle::Bearer, ))) } + // ── Anthropic-compatible custom endpoints ─────────── + // Format: "anthropic-custom:https://your-api.com" + name if name.starts_with("anthropic-custom:") => { + let base_url = parse_custom_provider_url( + name.strip_prefix("anthropic-custom:").unwrap_or(""), + "Anthropic-custom provider", + "anthropic-custom:https://your-api.com", + )?; + Ok(Box::new(anthropic::AnthropicProvider::with_base_url( + key, + Some(&base_url), + ))) + } + _ => anyhow::bail!( "Unknown provider: {name}. Check README for supported providers or run `zeroclaw onboard --interactive` to reconfigure.\n\ - Tip: Use \"custom:https://your-api.com\" for any OpenAI-compatible endpoint." + Tip: Use \"custom:https://your-api.com\" for OpenAI-compatible endpoints.\n\ + Tip: Use \"anthropic-custom:https://your-api.com\" for Anthropic-compatible endpoints." ), } } @@ -115,13 +536,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result, + api_url: Option<&str>, reliability: &crate::config::ReliabilityConfig, ) -> anyhow::Result> { let mut providers: Vec<(String, Box)> = Vec::new(); providers.push(( primary_name.to_string(), - create_provider(primary_name, api_key)?, + create_provider_with_url(primary_name, api_key, api_url)?, )); for fallback in &reliability.fallback_providers { @@ -129,49 +551,413 @@ pub fn create_resilient_provider( continue; } + // Fallback providers don't use the custom api_url (it's specific to primary) match create_provider(fallback, api_key) { Ok(provider) => providers.push((fallback.clone(), provider)), - Err(e) => { + Err(_error) => { tracing::warn!( fallback_provider = fallback, - "Ignoring invalid fallback provider: {e}" + "Ignoring invalid fallback provider during initialization" ); } } } - Ok(Box::new(ReliableProvider::new( + let reliable = ReliableProvider::new( providers, reliability.provider_retries, reliability.provider_backoff_ms, + ) + .with_api_keys(reliability.api_keys.clone()) + .with_model_fallbacks(reliability.model_fallbacks.clone()); + + Ok(Box::new(reliable)) +} + +/// Create a RouterProvider if model routes are configured, otherwise return a +/// standard resilient provider. The router wraps individual providers per route, +/// each with its own retry/fallback chain. +pub fn create_routed_provider( + primary_name: &str, + api_key: Option<&str>, + api_url: Option<&str>, + reliability: &crate::config::ReliabilityConfig, + model_routes: &[crate::config::ModelRouteConfig], + default_model: &str, +) -> anyhow::Result> { + if model_routes.is_empty() { + return create_resilient_provider(primary_name, api_key, api_url, reliability); + } + + // Collect unique provider names needed + let mut needed: Vec = vec![primary_name.to_string()]; + for route in model_routes { + if !needed.iter().any(|n| n == &route.provider) { + needed.push(route.provider.clone()); + } + } + + // Create each provider (with its own resilience wrapper) + let mut providers: Vec<(String, Box)> = Vec::new(); + for name in &needed { + let routed_credential = model_routes + .iter() + .find(|r| &r.provider == name) + .and_then(|r| { + r.api_key.as_ref().and_then(|raw_key| { + let trimmed_key = raw_key.trim(); + (!trimmed_key.is_empty()).then_some(trimmed_key) + }) + }); + let key = routed_credential.or(api_key); + // Only use api_url for the primary provider + let url = if name == primary_name { api_url } else { None }; + match create_resilient_provider(name, key, url, reliability) { + Ok(provider) => providers.push((name.clone(), provider)), + Err(e) => { + if name == primary_name { + return Err(e); + } + tracing::warn!( + provider = name.as_str(), + "Ignoring routed provider that failed to initialize" + ); + } + } + } + + // Build route table + let routes: Vec<(String, router::Route)> = model_routes + .iter() + .map(|r| { + ( + r.hint.clone(), + router::Route { + provider_name: r.provider.clone(), + model: r.model.clone(), + }, + ) + }) + .collect(); + + Ok(Box::new(router::RouterProvider::new( + providers, + routes, + default_model.to_string(), ))) } +/// Information about a supported provider for display purposes. +pub struct ProviderInfo { + /// Canonical name used in config (e.g. `"openrouter"`) + pub name: &'static str, + /// Human-readable display name + pub display_name: &'static str, + /// Alternative names accepted in config + pub aliases: &'static [&'static str], + /// Whether the provider runs locally (no API key required) + pub local: bool, +} + +/// Return the list of all known providers for display in `zeroclaw providers list`. +/// +/// This is intentionally separate from the factory match in `create_provider` +/// (display concern vs. construction concern). +pub fn list_providers() -> Vec { + vec![ + // ── Primary providers ──────────────────────────────── + ProviderInfo { + name: "openrouter", + display_name: "OpenRouter", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "anthropic", + display_name: "Anthropic", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "openai", + display_name: "OpenAI", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "ollama", + display_name: "Ollama", + aliases: &[], + local: true, + }, + ProviderInfo { + name: "gemini", + display_name: "Google Gemini", + aliases: &["google", "google-gemini"], + local: false, + }, + // ── OpenAI-compatible providers ────────────────────── + ProviderInfo { + name: "venice", + display_name: "Venice", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "vercel", + display_name: "Vercel AI Gateway", + aliases: &["vercel-ai"], + local: false, + }, + ProviderInfo { + name: "cloudflare", + display_name: "Cloudflare AI", + aliases: &["cloudflare-ai"], + local: false, + }, + ProviderInfo { + name: "moonshot", + display_name: "Moonshot", + aliases: &["kimi"], + local: false, + }, + ProviderInfo { + name: "synthetic", + display_name: "Synthetic", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "opencode", + display_name: "OpenCode Zen", + aliases: &["opencode-zen"], + local: false, + }, + ProviderInfo { + name: "zai", + display_name: "Z.AI", + aliases: &["z.ai"], + local: false, + }, + ProviderInfo { + name: "glm", + display_name: "GLM (Zhipu)", + aliases: &["zhipu"], + local: false, + }, + ProviderInfo { + name: "minimax", + display_name: "MiniMax", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "bedrock", + display_name: "Amazon Bedrock", + aliases: &["aws-bedrock"], + local: false, + }, + ProviderInfo { + name: "qianfan", + display_name: "Qianfan (Baidu)", + aliases: &["baidu"], + local: false, + }, + ProviderInfo { + name: "qwen", + display_name: "Qwen (DashScope)", + aliases: &[ + "dashscope", + "qwen-intl", + "dashscope-intl", + "qwen-us", + "dashscope-us", + ], + local: false, + }, + ProviderInfo { + name: "groq", + display_name: "Groq", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "mistral", + display_name: "Mistral", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "xai", + display_name: "xAI (Grok)", + aliases: &["grok"], + local: false, + }, + ProviderInfo { + name: "deepseek", + display_name: "DeepSeek", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "together", + display_name: "Together AI", + aliases: &["together-ai"], + local: false, + }, + ProviderInfo { + name: "fireworks", + display_name: "Fireworks AI", + aliases: &["fireworks-ai"], + local: false, + }, + ProviderInfo { + name: "perplexity", + display_name: "Perplexity", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "cohere", + display_name: "Cohere", + aliases: &[], + local: false, + }, + ProviderInfo { + name: "copilot", + display_name: "GitHub Copilot", + aliases: &["github-copilot"], + local: false, + }, + ProviderInfo { + name: "lmstudio", + display_name: "LM Studio", + aliases: &["lm-studio"], + local: true, + }, + ProviderInfo { + name: "nvidia", + display_name: "NVIDIA NIM", + aliases: &["nvidia-nim", "build.nvidia.com"], + local: false, + }, + ] +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn resolve_provider_credential_prefers_explicit_argument() { + let resolved = resolve_provider_credential("openrouter", Some(" explicit-key ")); + assert_eq!(resolved, Some("explicit-key".to_string())); + } + + #[test] + fn regional_alias_predicates_cover_expected_variants() { + assert!(is_moonshot_alias("moonshot")); + assert!(is_moonshot_alias("kimi-global")); + assert!(is_glm_alias("glm")); + assert!(is_glm_alias("bigmodel")); + assert!(is_minimax_alias("minimax-io")); + assert!(is_minimax_alias("minimaxi")); + assert!(is_qwen_alias("dashscope")); + assert!(is_qwen_alias("qwen-us")); + assert!(is_zai_alias("z.ai")); + assert!(is_zai_alias("zai-cn")); + assert!(is_qianfan_alias("qianfan")); + assert!(is_qianfan_alias("baidu")); + + assert!(!is_moonshot_alias("openrouter")); + assert!(!is_glm_alias("openai")); + assert!(!is_qwen_alias("gemini")); + assert!(!is_zai_alias("anthropic")); + assert!(!is_qianfan_alias("cohere")); + } + + #[test] + fn canonical_china_provider_name_maps_regional_aliases() { + assert_eq!(canonical_china_provider_name("moonshot"), Some("moonshot")); + assert_eq!(canonical_china_provider_name("kimi-intl"), Some("moonshot")); + assert_eq!(canonical_china_provider_name("glm"), Some("glm")); + assert_eq!(canonical_china_provider_name("zhipu-cn"), Some("glm")); + assert_eq!(canonical_china_provider_name("minimax"), Some("minimax")); + assert_eq!(canonical_china_provider_name("minimax-cn"), Some("minimax")); + assert_eq!(canonical_china_provider_name("qwen"), Some("qwen")); + assert_eq!(canonical_china_provider_name("dashscope-us"), Some("qwen")); + assert_eq!(canonical_china_provider_name("zai"), Some("zai")); + assert_eq!(canonical_china_provider_name("z.ai-cn"), Some("zai")); + assert_eq!(canonical_china_provider_name("qianfan"), Some("qianfan")); + assert_eq!(canonical_china_provider_name("baidu"), Some("qianfan")); + assert_eq!(canonical_china_provider_name("openai"), None); + } + + #[test] + fn regional_endpoint_aliases_map_to_expected_urls() { + assert_eq!(minimax_base_url("minimax"), Some(MINIMAX_INTL_BASE_URL)); + assert_eq!( + minimax_base_url("minimax-intl"), + Some(MINIMAX_INTL_BASE_URL) + ); + assert_eq!(minimax_base_url("minimax-cn"), Some(MINIMAX_CN_BASE_URL)); + + assert_eq!(glm_base_url("glm"), Some(GLM_GLOBAL_BASE_URL)); + assert_eq!(glm_base_url("glm-cn"), Some(GLM_CN_BASE_URL)); + assert_eq!(glm_base_url("bigmodel"), Some(GLM_CN_BASE_URL)); + + assert_eq!(moonshot_base_url("moonshot"), Some(MOONSHOT_CN_BASE_URL)); + assert_eq!( + moonshot_base_url("moonshot-intl"), + Some(MOONSHOT_INTL_BASE_URL) + ); + + assert_eq!(qwen_base_url("qwen"), Some(QWEN_CN_BASE_URL)); + assert_eq!(qwen_base_url("qwen-cn"), Some(QWEN_CN_BASE_URL)); + assert_eq!(qwen_base_url("qwen-intl"), Some(QWEN_INTL_BASE_URL)); + assert_eq!(qwen_base_url("qwen-us"), Some(QWEN_US_BASE_URL)); + + assert_eq!(zai_base_url("zai"), Some(ZAI_GLOBAL_BASE_URL)); + assert_eq!(zai_base_url("z.ai"), Some(ZAI_GLOBAL_BASE_URL)); + assert_eq!(zai_base_url("zai-global"), Some(ZAI_GLOBAL_BASE_URL)); + assert_eq!(zai_base_url("z.ai-global"), Some(ZAI_GLOBAL_BASE_URL)); + assert_eq!(zai_base_url("zai-cn"), Some(ZAI_CN_BASE_URL)); + assert_eq!(zai_base_url("z.ai-cn"), Some(ZAI_CN_BASE_URL)); + } + // ── Primary providers ──────────────────────────────────── #[test] fn factory_openrouter() { - assert!(create_provider("openrouter", Some("sk-test")).is_ok()); + assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok()); assert!(create_provider("openrouter", None).is_ok()); } #[test] fn factory_anthropic() { - assert!(create_provider("anthropic", Some("sk-test")).is_ok()); + assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok()); } #[test] fn factory_openai() { - assert!(create_provider("openai", Some("sk-test")).is_ok()); + assert!(create_provider("openai", Some("provider-test-credential")).is_ok()); } #[test] fn factory_ollama() { assert!(create_provider("ollama", None).is_ok()); + // Ollama may use API key when a remote endpoint is configured. + assert!(create_provider("ollama", Some("dummy")).is_ok()); + assert!(create_provider("ollama", Some("any-value-here")).is_ok()); + } + + #[test] + fn factory_gemini() { + assert!(create_provider("gemini", Some("test-key")).is_ok()); + assert!(create_provider("google", Some("test-key")).is_ok()); + assert!(create_provider("google-gemini", Some("test-key")).is_ok()); + // Should also work without key (will try CLI auth) + assert!(create_provider("gemini", None).is_ok()); } // ── OpenAI-compatible providers ────────────────────────── @@ -197,6 +983,10 @@ mod tests { fn factory_moonshot() { assert!(create_provider("moonshot", Some("key")).is_ok()); assert!(create_provider("kimi", Some("key")).is_ok()); + assert!(create_provider("moonshot-intl", Some("key")).is_ok()); + assert!(create_provider("moonshot-cn", Some("key")).is_ok()); + assert!(create_provider("kimi-intl", Some("key")).is_ok()); + assert!(create_provider("kimi-cn", Some("key")).is_ok()); } #[test] @@ -214,17 +1004,29 @@ mod tests { fn factory_zai() { assert!(create_provider("zai", Some("key")).is_ok()); assert!(create_provider("z.ai", Some("key")).is_ok()); + assert!(create_provider("zai-global", Some("key")).is_ok()); + assert!(create_provider("z.ai-global", Some("key")).is_ok()); + assert!(create_provider("zai-cn", Some("key")).is_ok()); + assert!(create_provider("z.ai-cn", Some("key")).is_ok()); } #[test] fn factory_glm() { assert!(create_provider("glm", Some("key")).is_ok()); assert!(create_provider("zhipu", Some("key")).is_ok()); + assert!(create_provider("glm-cn", Some("key")).is_ok()); + assert!(create_provider("zhipu-cn", Some("key")).is_ok()); + assert!(create_provider("glm-global", Some("key")).is_ok()); + assert!(create_provider("bigmodel", Some("key")).is_ok()); } #[test] fn factory_minimax() { assert!(create_provider("minimax", Some("key")).is_ok()); + assert!(create_provider("minimax-intl", Some("key")).is_ok()); + assert!(create_provider("minimax-io", Some("key")).is_ok()); + assert!(create_provider("minimax-cn", Some("key")).is_ok()); + assert!(create_provider("minimaxi", Some("key")).is_ok()); } #[test] @@ -239,6 +1041,27 @@ mod tests { assert!(create_provider("baidu", Some("key")).is_ok()); } + #[test] + fn factory_qwen() { + assert!(create_provider("qwen", Some("key")).is_ok()); + assert!(create_provider("dashscope", Some("key")).is_ok()); + assert!(create_provider("qwen-cn", Some("key")).is_ok()); + assert!(create_provider("dashscope-cn", Some("key")).is_ok()); + assert!(create_provider("qwen-intl", Some("key")).is_ok()); + assert!(create_provider("dashscope-intl", Some("key")).is_ok()); + assert!(create_provider("qwen-international", Some("key")).is_ok()); + assert!(create_provider("dashscope-international", Some("key")).is_ok()); + assert!(create_provider("qwen-us", Some("key")).is_ok()); + assert!(create_provider("dashscope-us", Some("key")).is_ok()); + } + + #[test] + fn factory_lmstudio() { + assert!(create_provider("lmstudio", Some("key")).is_ok()); + assert!(create_provider("lm-studio", Some("key")).is_ok()); + assert!(create_provider("lmstudio", None).is_ok()); + } + // ── Extended ecosystem ─────────────────────────────────── #[test] @@ -284,6 +1107,26 @@ mod tests { assert!(create_provider("cohere", Some("key")).is_ok()); } + #[test] + fn factory_copilot() { + assert!(create_provider("copilot", Some("key")).is_ok()); + assert!(create_provider("github-copilot", Some("key")).is_ok()); + } + + #[test] + fn factory_nvidia() { + assert!(create_provider("nvidia", Some("nvapi-test")).is_ok()); + assert!(create_provider("nvidia-nim", Some("nvapi-test")).is_ok()); + assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok()); + } + + // ── AI inference routers ───────────────────────────────── + + #[test] + fn factory_astrai() { + assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok()); + } + // ── Custom / BYOP provider ───────────────────────────── #[test] @@ -315,6 +1158,87 @@ mod tests { } } + #[test] + fn factory_custom_invalid_url_errors() { + match create_provider("custom:not-a-url", None) { + Err(e) => assert!( + e.to_string().contains("requires a valid URL"), + "Expected 'requires a valid URL', got: {e}" + ), + Ok(_) => panic!("Expected error for invalid custom URL"), + } + } + + #[test] + fn factory_custom_unsupported_scheme_errors() { + match create_provider("custom:ftp://example.com", None) { + Err(e) => assert!( + e.to_string().contains("http:// or https://"), + "Expected scheme validation error, got: {e}" + ), + Ok(_) => panic!("Expected error for unsupported custom URL scheme"), + } + } + + #[test] + fn factory_custom_trims_whitespace() { + let p = create_provider("custom: https://my-llm.example.com ", Some("key")); + assert!(p.is_ok()); + } + + // ── Anthropic-compatible custom endpoints ───────────────── + + #[test] + fn factory_anthropic_custom_url() { + let p = create_provider("anthropic-custom:https://api.example.com", Some("key")); + assert!(p.is_ok()); + } + + #[test] + fn factory_anthropic_custom_trailing_slash() { + let p = create_provider("anthropic-custom:https://api.example.com/", Some("key")); + assert!(p.is_ok()); + } + + #[test] + fn factory_anthropic_custom_no_key() { + let p = create_provider("anthropic-custom:https://api.example.com", None); + assert!(p.is_ok()); + } + + #[test] + fn factory_anthropic_custom_empty_url_errors() { + match create_provider("anthropic-custom:", None) { + Err(e) => assert!( + e.to_string().contains("requires a URL"), + "Expected 'requires a URL', got: {e}" + ), + Ok(_) => panic!("Expected error for empty anthropic-custom URL"), + } + } + + #[test] + fn factory_anthropic_custom_invalid_url_errors() { + match create_provider("anthropic-custom:not-a-url", None) { + Err(e) => assert!( + e.to_string().contains("requires a valid URL"), + "Expected 'requires a valid URL', got: {e}" + ), + Ok(_) => panic!("Expected error for invalid anthropic-custom URL"), + } + } + + #[test] + fn factory_anthropic_custom_unsupported_scheme_errors() { + match create_provider("anthropic-custom:ftp://example.com", None) { + Err(e) => assert!( + e.to_string().contains("http:// or https://"), + "Expected scheme validation error, got: {e}" + ), + Ok(_) => panic!("Expected error for unsupported anthropic-custom URL scheme"), + } + } + // ── Error cases ────────────────────────────────────────── #[test] @@ -342,23 +1266,48 @@ mod tests { "openai".into(), "openai".into(), ], + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, channel_max_backoff_secs: 60, scheduler_poll_secs: 15, scheduler_retries: 2, }; - let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); + let provider = create_resilient_provider( + "openrouter", + Some("provider-test-credential"), + None, + &reliability, + ); assert!(provider.is_ok()); } #[test] fn resilient_provider_errors_for_invalid_primary() { let reliability = crate::config::ReliabilityConfig::default(); - let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); + let provider = create_resilient_provider( + "totally-invalid", + Some("provider-test-credential"), + None, + &reliability, + ); assert!(provider.is_err()); } + #[test] + fn ollama_with_custom_url() { + let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434")); + assert!(provider.is_ok()); + } + + #[test] + fn ollama_cloud_with_custom_url() { + let provider = + create_provider_with_url("ollama", Some("ollama-key"), Some("https://ollama.com")); + assert!(provider.is_ok()); + } + #[test] fn factory_all_providers_create_successfully() { let providers = [ @@ -366,17 +1315,28 @@ mod tests { "anthropic", "openai", "ollama", + "gemini", "venice", "vercel", "cloudflare", "moonshot", + "moonshot-intl", + "moonshot-cn", "synthetic", "opencode", "zai", + "zai-cn", "glm", + "glm-cn", "minimax", + "minimax-cn", "bedrock", "qianfan", + "qwen", + "qwen-intl", + "qwen-cn", + "qwen-us", + "lmstudio", "groq", "mistral", "xai", @@ -385,6 +1345,8 @@ mod tests { "fireworks", "perplexity", "cohere", + "copilot", + "nvidia", ]; for name in providers { assert!( @@ -393,4 +1355,169 @@ mod tests { ); } } + + #[test] + fn listed_providers_have_unique_ids_and_aliases() { + let providers = list_providers(); + let mut canonical_ids = std::collections::HashSet::new(); + let mut aliases = std::collections::HashSet::new(); + + for provider in providers { + assert!( + canonical_ids.insert(provider.name), + "Duplicate canonical provider id: {}", + provider.name + ); + + for alias in provider.aliases { + assert_ne!( + *alias, provider.name, + "Alias must differ from canonical id: {}", + provider.name + ); + assert!( + !canonical_ids.contains(alias), + "Alias conflicts with canonical provider id: {}", + alias + ); + assert!(aliases.insert(alias), "Duplicate provider alias: {}", alias); + } + } + } + + #[test] + fn listed_providers_and_aliases_are_constructible() { + for provider in list_providers() { + assert!( + create_provider(provider.name, Some("provider-test-credential")).is_ok(), + "Canonical provider id should be constructible: {}", + provider.name + ); + + for alias in provider.aliases { + assert!( + create_provider(alias, Some("provider-test-credential")).is_ok(), + "Provider alias should be constructible: {} (for {})", + alias, + provider.name + ); + } + } + } + + // ── API error sanitization ─────────────────────────────── + + #[test] + fn sanitize_scrubs_sk_prefix() { + let input = "request failed: sk-1234567890abcdef"; + let out = sanitize_api_error(input); + assert!(!out.contains("sk-1234567890abcdef")); + assert!(out.contains("[REDACTED]")); + } + + #[test] + fn sanitize_scrubs_multiple_prefixes() { + let input = "keys sk-abcdef xoxb-12345 xoxp-67890"; + let out = sanitize_api_error(input); + assert!(!out.contains("sk-abcdef")); + assert!(!out.contains("xoxb-12345")); + assert!(!out.contains("xoxp-67890")); + } + + #[test] + fn sanitize_short_prefix_then_real_key() { + let input = "error with sk- prefix and key sk-1234567890"; + let result = sanitize_api_error(input); + assert!(!result.contains("sk-1234567890")); + assert!(result.contains("[REDACTED]")); + } + + #[test] + fn sanitize_sk_proj_comment_then_real_key() { + let input = "note: sk- then sk-proj-abc123def456"; + let result = sanitize_api_error(input); + assert!(!result.contains("sk-proj-abc123def456")); + assert!(result.contains("[REDACTED]")); + } + + #[test] + fn sanitize_keeps_bare_prefix() { + let input = "only prefix sk- present"; + let result = sanitize_api_error(input); + assert!(result.contains("sk-")); + } + + #[test] + fn sanitize_handles_json_wrapped_key() { + let input = r#"{"error":"invalid key sk-abc123xyz"}"#; + let result = sanitize_api_error(input); + assert!(!result.contains("sk-abc123xyz")); + } + + #[test] + fn sanitize_handles_delimiter_boundaries() { + let input = "bad token xoxb-abc123}; next"; + let result = sanitize_api_error(input); + assert!(!result.contains("xoxb-abc123")); + assert!(result.contains("};")); + } + + #[test] + fn sanitize_truncates_long_error() { + let long = "a".repeat(400); + let result = sanitize_api_error(&long); + assert!(result.len() <= 203); + assert!(result.ends_with("...")); + } + + #[test] + fn sanitize_truncates_after_scrub() { + let input = format!("{} sk-abcdef123456 {}", "a".repeat(190), "b".repeat(190)); + let result = sanitize_api_error(&input); + assert!(!result.contains("sk-abcdef123456")); + assert!(result.len() <= 203); + } + + #[test] + fn sanitize_preserves_unicode_boundaries() { + let input = format!("{} sk-abcdef123", "hello🙂".repeat(80)); + let result = sanitize_api_error(&input); + assert!(std::str::from_utf8(result.as_bytes()).is_ok()); + assert!(!result.contains("sk-abcdef123")); + } + + #[test] + fn sanitize_no_secret_no_change() { + let input = "simple upstream timeout"; + let result = sanitize_api_error(input); + assert_eq!(result, input); + } + + #[test] + fn scrub_github_personal_access_token() { + let input = "auth failed with token ghp_abc123def456"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "auth failed with token [REDACTED]"); + } + + #[test] + fn scrub_github_oauth_token() { + let input = "Bearer gho_1234567890abcdef"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "Bearer [REDACTED]"); + } + + #[test] + fn scrub_github_user_token() { + let input = "token ghu_sessiontoken123"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "token [REDACTED]"); + } + + #[test] + fn scrub_github_fine_grained_pat() { + let input = "failed: github_pat_11AABBC_xyzzy789"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "failed: [REDACTED]"); + } } diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index adc3e6e..498aa0c 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -5,9 +5,12 @@ use serde::{Deserialize, Serialize}; pub struct OllamaProvider { base_url: String, + api_key: Option, client: Client, } +// ─── Request Structures ─────────────────────────────────────────────────────── + #[derive(Debug, Serialize)] struct ChatRequest { model: String, @@ -27,30 +30,231 @@ struct Options { temperature: f64, } +// ─── Response Structures ────────────────────────────────────────────────────── + #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { message: ResponseMessage, } #[derive(Debug, Deserialize)] struct ResponseMessage { + #[serde(default)] content: String, + #[serde(default)] + tool_calls: Vec, + /// Some models return a "thinking" field with internal reasoning + #[serde(default)] + thinking: Option, } +#[derive(Debug, Deserialize)] +struct OllamaToolCall { + id: Option, + function: OllamaFunction, +} + +#[derive(Debug, Deserialize)] +struct OllamaFunction { + name: String, + #[serde(default)] + arguments: serde_json::Value, +} + +// ─── Implementation ─────────────────────────────────────────────────────────── + impl OllamaProvider { - pub fn new(base_url: Option<&str>) -> Self { + pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self { + let api_key = api_key.and_then(|value| { + let trimmed = value.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) + }); + Self { base_url: base_url .unwrap_or("http://localhost:11434") .trim_end_matches('/') .to_string(), + api_key, client: Client::builder() - .timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow + .timeout(std::time::Duration::from_secs(300)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .unwrap_or_else(|_| Client::new()), } } + + fn is_local_endpoint(&self) -> bool { + reqwest::Url::parse(&self.base_url) + .ok() + .and_then(|url| url.host_str().map(|host| host.to_string())) + .is_some_and(|host| matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1")) + } + + fn resolve_request_details(&self, model: &str) -> anyhow::Result<(String, bool)> { + let requests_cloud = model.ends_with(":cloud"); + let normalized_model = model.strip_suffix(":cloud").unwrap_or(model).to_string(); + + if requests_cloud && self.is_local_endpoint() { + anyhow::bail!( + "Model '{}' requested cloud routing, but Ollama endpoint is local. Configure api_url with a remote Ollama endpoint.", + model + ); + } + + if requests_cloud && self.api_key.is_none() { + anyhow::bail!( + "Model '{}' requested cloud routing, but no API key is configured. Set OLLAMA_API_KEY or config api_key.", + model + ); + } + + let should_auth = self.api_key.is_some() && !self.is_local_endpoint(); + + Ok((normalized_model, should_auth)) + } + + /// Send a request to Ollama and get the parsed response + async fn send_request( + &self, + messages: Vec, + model: &str, + temperature: f64, + should_auth: bool, + ) -> anyhow::Result { + let request = ChatRequest { + model: model.to_string(), + messages, + stream: false, + options: Options { temperature }, + }; + + let url = format!("{}/api/chat", self.base_url); + + tracing::debug!( + "Ollama request: url={} model={} message_count={} temperature={}", + url, + model, + request.messages.len(), + temperature + ); + + let mut request_builder = self.client.post(&url).json(&request); + + if should_auth { + if let Some(key) = self.api_key.as_ref() { + request_builder = request_builder.bearer_auth(key); + } + } + + let response = request_builder.send().await?; + let status = response.status(); + tracing::debug!("Ollama response status: {}", status); + + let body = response.bytes().await?; + tracing::debug!("Ollama response body length: {} bytes", body.len()); + + if !status.is_success() { + let raw = String::from_utf8_lossy(&body); + let sanitized = super::sanitize_api_error(&raw); + tracing::error!( + "Ollama error response: status={} body_excerpt={}", + status, + sanitized + ); + anyhow::bail!( + "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)", + status, + sanitized + ); + } + + let chat_response: ApiChatResponse = match serde_json::from_slice(&body) { + Ok(r) => r, + Err(e) => { + let raw = String::from_utf8_lossy(&body); + let sanitized = super::sanitize_api_error(&raw); + tracing::error!( + "Ollama response deserialization failed: {e}. body_excerpt={}", + sanitized + ); + anyhow::bail!("Failed to parse Ollama response: {e}"); + } + }; + + Ok(chat_response) + } + + /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs + /// + /// Handles quirky model behavior where tool calls are wrapped: + /// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}` + /// - `{"name": "tool.shell", "arguments": {...}}` + fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String { + let formatted_calls: Vec = tool_calls + .iter() + .map(|tc| { + let (tool_name, tool_args) = self.extract_tool_name_and_args(tc); + + // Arguments must be a JSON string for parse_tool_calls compatibility + let args_str = + serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string()); + + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tool_name, + "arguments": args_str + } + }) + }) + .collect(); + + serde_json::json!({ + "content": "", + "tool_calls": formatted_calls + }) + .to_string() + } + + /// Extract the actual tool name and arguments from potentially nested structures + fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) { + let name = &tc.function.name; + let args = &tc.function.arguments; + + // Pattern 1: Nested tool_call wrapper (various malformed versions) + // {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}} + // {"name": "tool_call>") + || name.starts_with("tool_call<") + { + if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) { + let nested_args = args + .get("arguments") + .cloned() + .unwrap_or(serde_json::json!({})); + tracing::debug!( + "Unwrapped nested tool call: {} -> {} with args {:?}", + name, + nested_name, + nested_args + ); + return (nested_name.to_string(), nested_args); + } + } + + // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.) + if let Some(stripped) = name.strip_prefix("tool.") { + return (stripped.to_string(), args.clone()); + } + + // Pattern 3: Normal tool call + (name.clone(), args.clone()) + } } #[async_trait] @@ -62,6 +266,8 @@ impl Provider for OllamaProvider { model: &str, temperature: f64, ) -> anyhow::Result { + let (normalized_model, should_auth) = self.resolve_request_details(model)?; + let mut messages = Vec::new(); if let Some(sys) = system_prompt { @@ -76,115 +282,281 @@ impl Provider for OllamaProvider { content: message.to_string(), }); - let request = ChatRequest { - model: model.to_string(), - messages, - stream: false, - options: Options { temperature }, - }; + let response = self + .send_request(messages, &normalized_model, temperature, should_auth) + .await?; - let url = format!("{}/api/chat", self.base_url); - - let response = self.client.post(&url).json(&request).send().await?; - - if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!( - "Ollama error: {error}. Is Ollama running? (brew install ollama && ollama serve)" + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); } - let chat_response: ChatResponse = response.json().await?; - Ok(chat_response.message.content) + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); + } + tracing::warn!("Ollama returned empty content with no tool calls"); + } + + Ok(content) + } + + async fn chat_with_history( + &self, + messages: &[crate::providers::ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (normalized_model, should_auth) = self.resolve_request_details(model)?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let response = self + .send_request(api_messages, &normalized_model, temperature, should_auth) + .await?; + + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() + ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); + } + + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + // This is a model quirk - it stopped after reasoning without producing output + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + // Return a message indicating the model's thought process but no action + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); + } + tracing::warn!("Ollama returned empty content with no tool calls"); + } + + Ok(content) + } + + fn supports_native_tools(&self) -> bool { + // Return false since loop_.rs uses XML-style tool parsing via system prompt + // The model may return native tool_calls but we convert them to JSON format + // that parse_tool_calls() understands + false } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; #[test] fn default_url() { - let p = OllamaProvider::new(None); + let p = OllamaProvider::new(None, None); assert_eq!(p.base_url, "http://localhost:11434"); } #[test] fn custom_url_trailing_slash() { - let p = OllamaProvider::new(Some("http://192.168.1.100:11434/")); + let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"), None); assert_eq!(p.base_url, "http://192.168.1.100:11434"); } #[test] fn custom_url_no_trailing_slash() { - let p = OllamaProvider::new(Some("http://myserver:11434")); + let p = OllamaProvider::new(Some("http://myserver:11434"), None); assert_eq!(p.base_url, "http://myserver:11434"); } #[test] fn empty_url_uses_empty() { - let p = OllamaProvider::new(Some("")); + let p = OllamaProvider::new(Some(""), None); assert_eq!(p.base_url, ""); } #[test] - fn request_serializes_with_system() { - let req = ChatRequest { - model: "llama3".to_string(), - messages: vec![ - Message { - role: "system".to_string(), - content: "You are ZeroClaw".to_string(), - }, - Message { - role: "user".to_string(), - content: "hello".to_string(), - }, - ], - stream: false, - options: Options { temperature: 0.7 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("\"stream\":false")); - assert!(json.contains("llama3")); - assert!(json.contains("system")); - assert!(json.contains("\"temperature\":0.7")); + fn cloud_suffix_strips_model_name() { + let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key")); + let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap(); + assert_eq!(model, "qwen3"); + assert!(should_auth); } #[test] - fn request_serializes_without_system() { - let req = ChatRequest { - model: "mistral".to_string(), - messages: vec![Message { - role: "user".to_string(), - content: "test".to_string(), - }], - stream: false, - options: Options { temperature: 0.0 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(!json.contains("\"role\":\"system\"")); - assert!(json.contains("mistral")); + fn cloud_suffix_with_local_endpoint_errors() { + let p = OllamaProvider::new(None, Some("ollama-key")); + let error = p + .resolve_request_details("qwen3:cloud") + .expect_err("cloud suffix should fail on local endpoint"); + assert!(error + .to_string() + .contains("requested cloud routing, but Ollama endpoint is local")); + } + + #[test] + fn cloud_suffix_without_api_key_errors() { + let p = OllamaProvider::new(Some("https://ollama.com"), None); + let error = p + .resolve_request_details("qwen3:cloud") + .expect_err("cloud suffix should require API key"); + assert!(error + .to_string() + .contains("requested cloud routing, but no API key is configured")); + } + + #[test] + fn remote_endpoint_auth_enabled_when_key_present() { + let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key")); + let (_model, should_auth) = p.resolve_request_details("qwen3").unwrap(); + assert!(should_auth); + } + + #[test] + fn local_endpoint_auth_disabled_even_with_key() { + let p = OllamaProvider::new(None, Some("ollama-key")); + let (_model, should_auth) = p.resolve_request_details("llama3").unwrap(); + assert!(!should_auth); } #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.message.content, "Hello from Ollama!"); } #[test] fn response_with_empty_content() { let json = r#"{"message":{"role":"assistant","content":""}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.is_empty()); } #[test] - fn response_with_multiline() { - let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.contains("line1")); + fn response_with_missing_content_defaults_to_empty() { + let json = r#"{"message":{"role":"assistant"}}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.message.content.is_empty()); + } + + #[test] + fn response_with_thinking_field_extracts_content() { + let json = + r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.message.content, "hello"); + } + + #[test] + fn response_with_tool_calls_parses_correctly() { + let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.message.content.is_empty()); + assert_eq!(resp.message.tool_calls.len(), 1); + assert_eq!(resp.message.tool_calls[0].function.name, "shell"); + } + + #[test] + fn extract_tool_name_handles_nested_tool_call() { + let provider = OllamaProvider::new(None, None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool_call".into(), + arguments: serde_json::json!({ + "name": "shell", + "arguments": {"command": "date"} + }), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "date"); + } + + #[test] + fn extract_tool_name_handles_prefixed_name() { + let provider = OllamaProvider::new(None, None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool.shell".into(), + arguments: serde_json::json!({"command": "ls"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "ls"); + } + + #[test] + fn extract_tool_name_handles_normal_call() { + let provider = OllamaProvider::new(None, None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "file_read".into(), + arguments: serde_json::json!({"path": "/tmp/test"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "file_read"); + assert_eq!(args.get("path").unwrap(), "/tmp/test"); + } + + #[test] + fn format_tool_calls_produces_valid_json() { + let provider = OllamaProvider::new(None, None); + let tool_calls = vec![OllamaToolCall { + id: Some("call_abc".into()), + function: OllamaFunction { + name: "shell".into(), + arguments: serde_json::json!({"command": "date"}), + }, + }]; + + let formatted = provider.format_tool_calls_for_loop(&tool_calls); + let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap(); + + assert!(parsed.get("tool_calls").is_some()); + let calls = parsed.get("tool_calls").unwrap().as_array().unwrap(); + assert_eq!(calls.len(), 1); + + let func = calls[0].get("function").unwrap(); + assert_eq!(func.get("name").unwrap(), "shell"); + // arguments should be a string (JSON-encoded) + assert!(func.get("arguments").unwrap().is_string()); } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 3481ce4..22b53ca 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,10 +1,14 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; pub struct OpenAiProvider { - api_key: Option, + credential: Option, client: Client, } @@ -36,10 +40,79 @@ struct ResponseMessage { content: String, } +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeChoice { + message: NativeResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct NativeResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + impl OpenAiProvider { - pub fn new(api_key: Option<&str>) -> Self { + pub fn new(credential: Option<&str>) -> Self { Self { - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -47,6 +120,107 @@ impl OpenAiProvider { .unwrap_or_else(|_| Client::new()), } } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + if m.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&m.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>( + tool_calls_value.clone(), + ) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if m.role == "tool" { + if let Ok(value) = serde_json::from_str::(&m.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + NativeMessage { + role: m.role.clone(), + content: Some(m.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ProviderToolCall { + id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tc.function.name, + arguments: tc.function.arguments, + }) + .collect::>(); + + ProviderChatResponse { + text: message.content, + tool_calls, + } + } } #[async_trait] @@ -58,7 +232,7 @@ impl Provider for OpenAiProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -85,14 +259,13 @@ impl Provider for OpenAiProvider { let response = self .client .post("https://api.openai.com/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .json(&request) .send() .await?; if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("OpenAI API error: {error}"); + return Err(super::api_error("OpenAI", response).await); } let chat_response: ChatResponse = response.json().await?; @@ -104,6 +277,51 @@ impl Provider for OpenAiProvider { .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") + })?; + + let tools = Self::convert_tools(request.tools); + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(request.messages), + temperature, + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; + + let response = self + .client + .post("https://api.openai.com/v1/chat/completions") + .header("Authorization", format!("Bearer {credential}")) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; + Ok(Self::parse_native_response(message)) + } + + fn supports_native_tools(&self) -> bool { + true + } } #[cfg(test)] @@ -112,20 +330,20 @@ mod tests { #[test] fn creates_with_key() { - let p = OpenAiProvider::new(Some("sk-proj-abc123")); - assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123")); + let p = OpenAiProvider::new(Some("openai-test-credential")); + assert_eq!(p.credential.as_deref(), Some("openai-test-credential")); } #[test] fn creates_without_key() { let p = OpenAiProvider::new(None); - assert!(p.api_key.is_none()); + assert!(p.credential.is_none()); } #[test] fn creates_with_empty_key() { let p = OpenAiProvider::new(Some("")); - assert_eq!(p.api_key.as_deref(), Some("")); + assert_eq!(p.credential.as_deref(), Some("")); } #[tokio::test] diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 3d99481..b27bff4 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,10 +1,14 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; pub struct OpenRouterProvider { - api_key: Option, + credential: Option, client: Client, } @@ -22,7 +26,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -36,10 +40,79 @@ struct ResponseMessage { content: String, } +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeChoice { + message: NativeResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct NativeResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + impl OpenRouterProvider { - pub fn new(api_key: Option<&str>) -> Self { + pub fn new(credential: Option<&str>) -> Self { Self { - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -47,10 +120,129 @@ impl OpenRouterProvider { .unwrap_or_else(|_| Client::new()), } } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + let items = tools?; + if items.is_empty() { + return None; + } + Some( + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect(), + ) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + if m.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&m.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>( + tool_calls_value.clone(), + ) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if m.role == "tool" { + if let Ok(value) = serde_json::from_str::(&m.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + NativeMessage { + role: m.role.clone(), + content: Some(m.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ProviderToolCall { + id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tc.function.name, + arguments: tc.function.arguments, + }) + .collect::>(); + + ProviderChatResponse { + text: message.content, + tool_calls, + } + } } #[async_trait] impl Provider for OpenRouterProvider { + async fn warmup(&self) -> anyhow::Result<()> { + // Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool. + // This prevents the first real chat request from timing out on cold start. + if let Some(credential) = self.credential.as_ref() { + self.client + .get("https://openrouter.ai/api/v1/auth/key") + .header("Authorization", format!("Bearer {credential}")) + .send() + .await? + .error_for_status()?; + } + Ok(()) + } + async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -58,7 +250,7 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref() + let credential = self.credential.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; let mut messages = Vec::new(); @@ -84,7 +276,7 @@ impl Provider for OpenRouterProvider { let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .header( "HTTP-Referer", "https://github.com/theonlyhennygod/zeroclaw", @@ -95,11 +287,10 @@ impl Provider for OpenRouterProvider { .await?; if !response.status().is_success() { - let error = response.text().await?; - anyhow::bail!("OpenRouter API error: {error}"); + return Err(super::api_error("OpenRouter", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices @@ -108,4 +299,455 @@ impl Provider for OpenRouterProvider { .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref() + .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {credential}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let chat_response: ApiChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." + ) + })?; + + let tools = Self::convert_tools(request.tools); + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(request.messages), + temperature, + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {credential}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; + Ok(Self::parse_native_response(message)) + } + + fn supports_native_tools(&self) -> bool { + true + } + + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." + ) + })?; + + // Convert tool JSON values to NativeToolSpec + let native_tools: Option> = if tools.is_empty() { + None + } else { + let specs: Vec = tools + .iter() + .filter_map(|t| { + let func = t.get("function")?; + Some(NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: func.get("name")?.as_str()?.to_string(), + description: func + .get("description") + .and_then(|d| d.as_str()) + .unwrap_or("") + .to_string(), + parameters: func + .get("parameters") + .cloned() + .unwrap_or(serde_json::json!({})), + }, + }) + }) + .collect(); + if specs.is_empty() { + None + } else { + Some(specs) + } + }; + + // Convert ChatMessage to NativeMessage, preserving structured assistant/tool entries + // when history contains native tool-call metadata. + let native_messages = Self::convert_messages(messages); + + let native_request = NativeChatRequest { + model: model.to_string(), + messages: native_messages, + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {credential}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; + Ok(Self::parse_native_response(message)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::traits::{ChatMessage, Provider}; + + #[test] + fn creates_with_key() { + let provider = OpenRouterProvider::new(Some("openrouter-test-credential")); + assert_eq!( + provider.credential.as_deref(), + Some("openrouter-test-credential") + ); + } + + #[test] + fn creates_without_key() { + let provider = OpenRouterProvider::new(None); + assert!(provider.credential.is_none()); + } + + #[tokio::test] + async fn warmup_without_key_is_noop() { + let provider = OpenRouterProvider::new(None); + let result = provider.warmup().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn chat_with_system_fails_without_key() { + let provider = OpenRouterProvider::new(None); + let result = provider + .chat_with_system(Some("system"), "hello", "openai/gpt-4o", 0.2) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[tokio::test] + async fn chat_with_history_fails_without_key() { + let provider = OpenRouterProvider::new(None); + let messages = vec![ + ChatMessage { + role: "system".into(), + content: "be concise".into(), + }, + ChatMessage { + role: "user".into(), + content: "hello".into(), + }, + ]; + + let result = provider + .chat_with_history(&messages, "anthropic/claude-sonnet-4", 0.7) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[test] + fn chat_request_serializes_with_system_and_user() { + let request = ChatRequest { + model: "anthropic/claude-sonnet-4".into(), + messages: vec![ + Message { + role: "system".into(), + content: "You are helpful".into(), + }, + Message { + role: "user".into(), + content: "Summarize this".into(), + }, + ], + temperature: 0.5, + }; + + let json = serde_json::to_string(&request).unwrap(); + + assert!(json.contains("anthropic/claude-sonnet-4")); + assert!(json.contains("\"role\":\"system\"")); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"temperature\":0.5")); + } + + #[test] + fn chat_request_serializes_history_messages() { + let messages = [ + ChatMessage { + role: "assistant".into(), + content: "Previous answer".into(), + }, + ChatMessage { + role: "user".into(), + content: "Follow-up".into(), + }, + ]; + + let request = ChatRequest { + model: "google/gemini-2.5-pro".into(), + messages: messages + .iter() + .map(|msg| Message { + role: msg.role.clone(), + content: msg.content.clone(), + }) + .collect(), + temperature: 0.0, + }; + + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("\"role\":\"assistant\"")); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("google/gemini-2.5-pro")); + } + + #[test] + fn response_deserializes_single_choice() { + let json = r#"{"choices":[{"message":{"content":"Hi from OpenRouter"}}]}"#; + + let response: ApiChatResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.choices.len(), 1); + assert_eq!(response.choices[0].message.content, "Hi from OpenRouter"); + } + + #[test] + fn response_deserializes_empty_choices() { + let json = r#"{"choices":[]}"#; + + let response: ApiChatResponse = serde_json::from_str(json).unwrap(); + + assert!(response.choices.is_empty()); + } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let provider = OpenRouterProvider::new(None); + let messages = vec![ChatMessage { + role: "user".into(), + content: "What is the date?".into(), + }]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {"command": {"type": "string"}}} + } + })]; + + let result = provider + .chat_with_tools(&messages, &tools, "deepseek/deepseek-chat", 0.5) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[test] + fn native_response_deserializes_with_tool_calls() { + let json = r#"{ + "choices":[{ + "message":{ + "content":null, + "tool_calls":[ + {"id":"call_123","type":"function","function":{"name":"get_price","arguments":"{\"symbol\":\"BTC\"}"}} + ] + } + }] + }"#; + + let response: NativeChatResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.choices.len(), 1); + let message = &response.choices[0].message; + assert!(message.content.is_none()); + let tool_calls = message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id.as_deref(), Some("call_123")); + assert_eq!(tool_calls[0].function.name, "get_price"); + assert_eq!(tool_calls[0].function.arguments, "{\"symbol\":\"BTC\"}"); + } + + #[test] + fn native_response_deserializes_with_text_and_tool_calls() { + let json = r#"{ + "choices":[{ + "message":{ + "content":"I'll get that for you.", + "tool_calls":[ + {"id":"call_456","type":"function","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}} + ] + } + }] + }"#; + + let response: NativeChatResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.choices.len(), 1); + let message = &response.choices[0].message; + assert_eq!(message.content.as_deref(), Some("I'll get that for you.")); + let tool_calls = message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "shell"); + } + + #[test] + fn parse_native_response_converts_to_chat_response() { + let message = NativeResponseMessage { + content: Some("Here you go.".into()), + tool_calls: Some(vec![NativeToolCall { + id: Some("call_789".into()), + kind: Some("function".into()), + function: NativeFunctionCall { + name: "file_read".into(), + arguments: r#"{"path":"test.txt"}"#.into(), + }, + }]), + }; + + let response = OpenRouterProvider::parse_native_response(message); + + assert_eq!(response.text.as_deref(), Some("Here you go.")); + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].id, "call_789"); + assert_eq!(response.tool_calls[0].name, "file_read"); + } + + #[test] + fn convert_messages_parses_assistant_tool_call_payload() { + let messages = vec![ChatMessage { + role: "assistant".into(), + content: r#"{"content":"Using tool","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{\"command\":\"pwd\"}"}]}"# + .into(), + }]; + + let converted = OpenRouterProvider::convert_messages(&messages); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "assistant"); + assert_eq!(converted[0].content.as_deref(), Some("Using tool")); + + let tool_calls = converted[0].tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc")); + assert_eq!(tool_calls[0].function.name, "shell"); + assert_eq!(tool_calls[0].function.arguments, r#"{"command":"pwd"}"#); + } + + #[test] + fn convert_messages_parses_tool_result_payload() { + let messages = vec![ChatMessage { + role: "tool".into(), + content: r#"{"tool_call_id":"call_xyz","content":"done"}"#.into(), + }]; + + let converted = OpenRouterProvider::convert_messages(&messages); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "tool"); + assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_xyz")); + assert_eq!(converted[0].content.as_deref(), Some("done")); + assert!(converted[0].tool_calls.is_none()); + } } diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index c324f21..fe49d35 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,12 +1,85 @@ +use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult}; use super::Provider; use async_trait::async_trait; +use futures_util::{stream, StreamExt}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; -/// Provider wrapper with retry + fallback behavior. +/// Check if an error is non-retryable (client errors that won't resolve with retries). +fn is_non_retryable(err: &anyhow::Error) -> bool { + if let Some(reqwest_err) = err.downcast_ref::() { + if let Some(status) = reqwest_err.status() { + let code = status.as_u16(); + return status.is_client_error() && code != 429 && code != 408; + } + } + let msg = err.to_string(); + for word in msg.split(|c: char| !c.is_ascii_digit()) { + if let Ok(code) = word.parse::() { + if (400..500).contains(&code) { + return code != 429 && code != 408; + } + } + } + false +} + +/// Check if an error is a rate-limit (429) error. +fn is_rate_limited(err: &anyhow::Error) -> bool { + if let Some(reqwest_err) = err.downcast_ref::() { + if let Some(status) = reqwest_err.status() { + return status.as_u16() == 429; + } + } + let msg = err.to_string(); + msg.contains("429") + && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit")) +} + +/// Try to extract a Retry-After value (in milliseconds) from an error message. +/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string. +fn parse_retry_after_ms(err: &anyhow::Error) -> Option { + let msg = err.to_string(); + let lower = msg.to_lowercase(); + + // Look for "retry-after: " or "retry_after: " + for prefix in &[ + "retry-after:", + "retry_after:", + "retry-after ", + "retry_after ", + ] { + if let Some(pos) = lower.find(prefix) { + let after = &msg[pos + prefix.len()..]; + let num_str: String = after + .trim() + .chars() + .take_while(|c| c.is_ascii_digit() || *c == '.') + .collect(); + if let Ok(secs) = num_str.parse::() { + if secs.is_finite() && secs >= 0.0 { + let millis = Duration::from_secs_f64(secs).as_millis(); + if let Ok(value) = u64::try_from(millis) { + return Some(value); + } + } + } + } + } + None +} + +/// Provider wrapper with retry, fallback, auth rotation, and model failover. pub struct ReliableProvider { providers: Vec<(String, Box)>, max_retries: u32, base_backoff_ms: u64, + /// Extra API keys for rotation (index tracks round-robin position). + api_keys: Vec, + key_index: AtomicUsize, + /// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...] + model_fallbacks: HashMap>, } impl ReliableProvider { @@ -19,12 +92,65 @@ impl ReliableProvider { providers, max_retries, base_backoff_ms: base_backoff_ms.max(50), + api_keys: Vec::new(), + key_index: AtomicUsize::new(0), + model_fallbacks: HashMap::new(), + } + } + + /// Set additional API keys for round-robin rotation on rate-limit errors. + pub fn with_api_keys(mut self, keys: Vec) -> Self { + self.api_keys = keys; + self + } + + /// Set per-model fallback chains. + pub fn with_model_fallbacks(mut self, fallbacks: HashMap>) -> Self { + self.model_fallbacks = fallbacks; + self + } + + /// Build the list of models to try: [original, fallback1, fallback2, ...] + fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> { + let mut chain = vec![model]; + if let Some(fallbacks) = self.model_fallbacks.get(model) { + chain.extend(fallbacks.iter().map(|s| s.as_str())); + } + chain + } + + /// Advance to the next API key and return it, or None if no extra keys configured. + fn rotate_key(&self) -> Option<&str> { + if self.api_keys.is_empty() { + return None; + } + let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len(); + Some(&self.api_keys[idx]) + } + + /// Compute backoff duration, respecting Retry-After if present. + fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 { + if let Some(retry_after) = parse_retry_after_ms(err) { + // Use Retry-After but cap at 30s to avoid indefinite waits + retry_after.min(30_000).max(base) + } else { + base } } } #[async_trait] impl Provider for ReliableProvider { + async fn warmup(&self) -> anyhow::Result<()> { + for (name, provider) in &self.providers { + tracing::info!(provider = name, "Warming up provider connection pool"); + if provider.warmup().await.is_err() { + tracing::warn!(provider = name, "Warmup failed (non-fatal)"); + } + } + Ok(()) + } + async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -32,58 +158,278 @@ impl Provider for ReliableProvider { model: &str, temperature: f64, ) -> anyhow::Result { + let models = self.model_chain(model); let mut failures = Vec::new(); - for (provider_name, provider) in &self.providers { - let mut backoff_ms = self.base_backoff_ms; + for current_model in &models { + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; - for attempt in 0..=self.max_retries { - match provider - .chat_with_system(system_prompt, message, model, temperature) - .await - { - Ok(resp) => { - if attempt > 0 { - tracing::info!( - provider = provider_name, - attempt, - "Provider recovered after retries" - ); + for attempt in 0..=self.max_retries { + match provider + .chat_with_system(system_prompt, message, current_model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 || *current_model != model { + tracing::info!( + provider = provider_name, + model = *current_model, + attempt, + original_model = model, + "Provider recovered (failover/retry)" + ); + } + return Ok(resp); } - return Ok(resp); - } - Err(e) => { - failures.push(format!( - "{provider_name} attempt {}/{}: {e}", - attempt + 1, - self.max_retries + 1 - )); + Err(e) => { + let non_retryable = is_non_retryable(&e); + let rate_limited = is_rate_limited(&e); - if attempt < self.max_retries { - tracing::warn!( - provider = provider_name, - attempt = attempt + 1, - max_retries = self.max_retries, - "Provider call failed, retrying" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + let failure_reason = if rate_limited { + "rate_limited" + } else if non_retryable { + "non_retryable" + } else { + "retryable" + }; + failures.push(format!( + "{provider_name}/{current_model} attempt {}/{}: {failure_reason}", + attempt + 1, + self.max_retries + 1 + )); + + // On rate-limit, try rotating API key + if rate_limited { + if let Some(new_key) = self.rotate_key() { + tracing::info!( + provider = provider_name, + "Rate limited, rotated API key (key ending ...{})", + &new_key[new_key.len().saturating_sub(4)..] + ); + } + } + + if non_retryable { + tracing::warn!( + provider = provider_name, + model = *current_model, + "Non-retryable error, moving on" + ); + break; + } + + if attempt < self.max_retries { + let wait = self.compute_backoff(backoff_ms, &e); + tracing::warn!( + provider = provider_name, + model = *current_model, + attempt = attempt + 1, + backoff_ms = wait, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(wait)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } } } } + + tracing::warn!( + provider = provider_name, + model = *current_model, + "Exhausted retries, trying next provider/model" + ); } - tracing::warn!(provider = provider_name, "Switching to fallback provider"); + if *current_model != model { + tracing::warn!( + original_model = model, + fallback_model = *current_model, + "Model fallback exhausted all providers, trying next fallback model" + ); + } } - anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) + anyhow::bail!( + "All providers/models failed. Attempts:\n{}", + failures.join("\n") + ) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let models = self.model_chain(model); + let mut failures = Vec::new(); + + for current_model in &models { + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; + + for attempt in 0..=self.max_retries { + match provider + .chat_with_history(messages, current_model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 || *current_model != model { + tracing::info!( + provider = provider_name, + model = *current_model, + attempt, + original_model = model, + "Provider recovered (failover/retry)" + ); + } + return Ok(resp); + } + Err(e) => { + let non_retryable = is_non_retryable(&e); + let rate_limited = is_rate_limited(&e); + + let failure_reason = if rate_limited { + "rate_limited" + } else if non_retryable { + "non_retryable" + } else { + "retryable" + }; + failures.push(format!( + "{provider_name}/{current_model} attempt {}/{}: {failure_reason}", + attempt + 1, + self.max_retries + 1 + )); + + if rate_limited { + if let Some(new_key) = self.rotate_key() { + tracing::info!( + provider = provider_name, + "Rate limited, rotated API key (key ending ...{})", + &new_key[new_key.len().saturating_sub(4)..] + ); + } + } + + if non_retryable { + tracing::warn!( + provider = provider_name, + model = *current_model, + "Non-retryable error, moving on" + ); + break; + } + + if attempt < self.max_retries { + let wait = self.compute_backoff(backoff_ms, &e); + tracing::warn!( + provider = provider_name, + model = *current_model, + attempt = attempt + 1, + backoff_ms = wait, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(wait)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } + } + } + } + + tracing::warn!( + provider = provider_name, + model = *current_model, + "Exhausted retries, trying next provider/model" + ); + } + } + + anyhow::bail!( + "All providers/models failed. Attempts:\n{}", + failures.join("\n") + ) + } + + fn supports_streaming(&self) -> bool { + self.providers.iter().any(|(_, p)| p.supports_streaming()) + } + + fn stream_chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + options: StreamOptions, + ) -> stream::BoxStream<'static, StreamResult> { + // Try each provider/model combination for streaming + // For streaming, we use the first provider that supports it and has streaming enabled + for (provider_name, provider) in &self.providers { + if !provider.supports_streaming() || !options.enabled { + continue; + } + + // Clone provider data for the stream + let provider_clone = provider_name.clone(); + + // Try the first model in the chain for streaming + let current_model = match self.model_chain(model).first() { + Some(m) => m.to_string(), + None => model.to_string(), + }; + + // For streaming, we attempt once and propagate errors + // The caller can retry the entire request if needed + let stream = provider.stream_chat_with_system( + system_prompt, + message, + ¤t_model, + temperature, + options, + ); + + // Use a channel to bridge the stream with logging + let (tx, rx) = tokio::sync::mpsc::channel::>(100); + + tokio::spawn(async move { + let mut stream = stream; + while let Some(chunk) = stream.next().await { + if let Err(ref e) = chunk { + tracing::warn!( + provider = provider_clone, + model = current_model, + "Streaming error: {e}" + ); + } + if tx.send(chunk).await.is_err() { + break; // Receiver dropped + } + } + }); + + // Convert channel receiver to stream + return stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|chunk| (chunk, rx)) + }) + .boxed(); + } + + // No streaming support available + stream::once(async move { + Err(super::traits::StreamError::Provider( + "No provider supports streaming".to_string(), + )) + }) + .boxed() } } #[cfg(test)] mod tests { use super::*; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; struct MockProvider { @@ -108,8 +454,49 @@ mod tests { } Ok(self.response.to_string()) } + + async fn chat_with_history( + &self, + _messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + if attempt <= self.fail_until_attempt { + anyhow::bail!(self.error); + } + Ok(self.response.to_string()) + } } + /// Mock that records which model was used for each call. + struct ModelAwareMock { + calls: Arc, + models_seen: parking_lot::Mutex>, + fail_models: Vec<&'static str>, + response: &'static str, + } + + #[async_trait] + impl Provider for ModelAwareMock { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + self.models_seen.lock().push(model.to_string()); + if self.fail_models.contains(&model) { + anyhow::bail!("500 model {} unavailable", model); + } + Ok(self.response.to_string()) + } + } + + // ── Existing tests (preserved) ── + #[tokio::test] async fn succeeds_without_retry() { let calls = Arc::new(AtomicUsize::new(0)); @@ -127,7 +514,7 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); assert_eq!(result, "ok"); assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -149,7 +536,7 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); assert_eq!(result, "recovered"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -184,7 +571,7 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); assert_eq!(result, "from fallback"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); @@ -218,12 +605,323 @@ mod tests { ); let err = provider - .chat("hello", "test", 0.0) + .simple_chat("hello", "test", 0.0) .await .expect_err("all providers should fail"); let msg = err.to_string(); - assert!(msg.contains("All providers failed")); - assert!(msg.contains("p1 attempt 1/1")); - assert!(msg.contains("p2 attempt 1/1")); + assert!(msg.contains("All providers/models failed")); + assert!(msg.contains("p1")); + assert!(msg.contains("p2")); + } + + #[test] + fn non_retryable_detects_common_patterns() { + assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request"))); + assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized"))); + assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden"))); + assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found"))); + assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests"))); + assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout"))); + assert!(!is_non_retryable(&anyhow::anyhow!( + "500 Internal Server Error" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway"))); + assert!(!is_non_retryable(&anyhow::anyhow!("timeout"))); + assert!(!is_non_retryable(&anyhow::anyhow!("connection reset"))); + } + + #[tokio::test] + async fn skips_retries_on_non_retryable_error() { + let primary_calls = Arc::new(AtomicUsize::new(0)); + let fallback_calls = Arc::new(AtomicUsize::new(0)); + + let provider = ReliableProvider::new( + vec![ + ( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&primary_calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "401 Unauthorized", + }), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "from fallback", + error: "fallback err", + }), + ), + ], + 3, + 1, + ); + + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); + // Primary should have been called only once (no retries) + assert_eq!(primary_calls.load(Ordering::SeqCst), 1); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn chat_with_history_retries_then_recovers() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 1, + response: "history ok", + error: "temporary", + }), + )], + 2, + 1, + ); + + let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")]; + let result = provider + .chat_with_history(&messages, "test", 0.0) + .await + .unwrap(); + assert_eq!(result, "history ok"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn chat_with_history_falls_back() { + let primary_calls = Arc::new(AtomicUsize::new(0)); + let fallback_calls = Arc::new(AtomicUsize::new(0)); + + let provider = ReliableProvider::new( + vec![ + ( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&primary_calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "primary down", + }), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "fallback ok", + error: "fallback err", + }), + ), + ], + 1, + 1, + ); + + let messages = vec![ChatMessage::user("hello")]; + let result = provider + .chat_with_history(&messages, "test", 0.0) + .await + .unwrap(); + assert_eq!(result, "fallback ok"); + assert_eq!(primary_calls.load(Ordering::SeqCst), 2); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } + + // ── New tests: model failover ── + + #[tokio::test] + async fn model_failover_tries_fallback_model() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = Arc::new(ModelAwareMock { + calls: Arc::clone(&calls), + models_seen: parking_lot::Mutex::new(Vec::new()), + fail_models: vec!["claude-opus"], + response: "ok from sonnet", + }); + + let mut fallbacks = HashMap::new(); + fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]); + + let provider = ReliableProvider::new( + vec![( + "anthropic".into(), + Box::new(mock.clone()) as Box, + )], + 0, // no retries — force immediate model failover + 1, + ) + .with_model_fallbacks(fallbacks); + + let result = provider + .simple_chat("hello", "claude-opus", 0.0) + .await + .unwrap(); + assert_eq!(result, "ok from sonnet"); + + let seen = mock.models_seen.lock(); + assert_eq!(seen.len(), 2); + assert_eq!(seen[0], "claude-opus"); + assert_eq!(seen[1], "claude-sonnet"); + } + + #[tokio::test] + async fn model_failover_all_models_fail() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = Arc::new(ModelAwareMock { + calls: Arc::clone(&calls), + models_seen: parking_lot::Mutex::new(Vec::new()), + fail_models: vec!["model-a", "model-b", "model-c"], + response: "never", + }); + + let mut fallbacks = HashMap::new(); + fallbacks.insert( + "model-a".to_string(), + vec!["model-b".to_string(), "model-c".to_string()], + ); + + let provider = ReliableProvider::new( + vec![("p1".into(), Box::new(mock.clone()) as Box)], + 0, + 1, + ) + .with_model_fallbacks(fallbacks); + + let err = provider + .simple_chat("hello", "model-a", 0.0) + .await + .expect_err("all models should fail"); + assert!(err.to_string().contains("All providers/models failed")); + + let seen = mock.models_seen.lock(); + assert_eq!(seen.len(), 3); + } + + #[tokio::test] + async fn no_model_fallbacks_behaves_like_before() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "boom", + }), + )], + 2, + 1, + ); + // No model_fallbacks set — should work exactly as before + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "ok"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + // ── New tests: auth rotation ── + + #[tokio::test] + async fn auth_rotation_cycles_keys() { + let provider = ReliableProvider::new( + vec![( + "p".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: 0, + response: "ok", + error: "", + }), + )], + 0, + 1, + ) + .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]); + + // Rotate 5 times, verify round-robin + let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect(); + assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]); + } + + #[tokio::test] + async fn auth_rotation_returns_none_when_empty() { + let provider = ReliableProvider::new(vec![], 0, 1); + assert!(provider.rotate_key().is_none()); + } + + // ── New tests: Retry-After parsing ── + + #[test] + fn parse_retry_after_integer() { + let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5"); + assert_eq!(parse_retry_after_ms(&err), Some(5000)); + } + + #[test] + fn parse_retry_after_float() { + let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds"); + assert_eq!(parse_retry_after_ms(&err), Some(2500)); + } + + #[test] + fn parse_retry_after_missing() { + let err = anyhow::anyhow!("500 Internal Server Error"); + assert_eq!(parse_retry_after_ms(&err), None); + } + + #[test] + fn rate_limited_detection() { + assert!(is_rate_limited(&anyhow::anyhow!("429 Too Many Requests"))); + assert!(is_rate_limited(&anyhow::anyhow!( + "HTTP 429 rate limit exceeded" + ))); + assert!(!is_rate_limited(&anyhow::anyhow!("401 Unauthorized"))); + assert!(!is_rate_limited(&anyhow::anyhow!( + "500 Internal Server Error" + ))); + } + + #[test] + fn compute_backoff_uses_retry_after() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("429 Retry-After: 3"); + assert_eq!(provider.compute_backoff(500, &err), 3000); + } + + #[test] + fn compute_backoff_caps_at_30s() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("429 Retry-After: 120"); + assert_eq!(provider.compute_backoff(500, &err), 30_000); + } + + #[test] + fn compute_backoff_falls_back_to_base() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("500 Server Error"); + assert_eq!(provider.compute_backoff(500, &err), 500); + } + + // ── Arc Provider impl for test ── + + #[async_trait] + impl Provider for Arc { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.as_ref() + .chat_with_system(system_prompt, message, model, temperature) + .await + } } } diff --git a/src/providers/router.rs b/src/providers/router.rs new file mode 100644 index 0000000..78edde0 --- /dev/null +++ b/src/providers/router.rs @@ -0,0 +1,385 @@ +use super::traits::{ChatMessage, ChatRequest, ChatResponse}; +use super::Provider; +use async_trait::async_trait; +use std::collections::HashMap; + +/// A single route: maps a task hint to a provider + model combo. +#[derive(Debug, Clone)] +pub struct Route { + pub provider_name: String, + pub model: String, +} + +/// Multi-model router — routes requests to different provider+model combos +/// based on a task hint encoded in the model parameter. +/// +/// The model parameter can be: +/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default provider +/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table +/// +/// This wraps multiple pre-created providers and selects the right one per request. +pub struct RouterProvider { + routes: HashMap, // hint → (provider_index, model) + providers: Vec<(String, Box)>, + default_index: usize, + default_model: String, +} + +impl RouterProvider { + /// Create a new router with a default provider and optional routes. + /// + /// `providers` is a list of (name, provider) pairs. The first one is the default. + /// `routes` maps hint names to Route structs containing provider_name and model. + pub fn new( + providers: Vec<(String, Box)>, + routes: Vec<(String, Route)>, + default_model: String, + ) -> Self { + // Build provider name → index lookup + let name_to_index: HashMap<&str, usize> = providers + .iter() + .enumerate() + .map(|(i, (name, _))| (name.as_str(), i)) + .collect(); + + // Resolve routes to provider indices + let resolved_routes: HashMap = routes + .into_iter() + .filter_map(|(hint, route)| { + let index = name_to_index.get(route.provider_name.as_str()).copied(); + match index { + Some(i) => Some((hint, (i, route.model))), + None => { + tracing::warn!( + hint = hint, + provider = route.provider_name, + "Route references unknown provider, skipping" + ); + None + } + } + }) + .collect(); + + Self { + routes: resolved_routes, + providers, + default_index: 0, + default_model, + } + } + + /// Resolve a model parameter to a (provider, actual_model) pair. + /// + /// If the model starts with "hint:", look up the hint in the route table. + /// Otherwise, use the default provider with the given model name. + /// Resolve a model parameter to a (provider_index, actual_model) pair. + fn resolve(&self, model: &str) -> (usize, String) { + if let Some(hint) = model.strip_prefix("hint:") { + if let Some((idx, resolved_model)) = self.routes.get(hint) { + return (*idx, resolved_model.clone()); + } + tracing::warn!( + hint = hint, + "Unknown route hint, falling back to default provider" + ); + } + + // Not a hint or hint not found — use default provider with the model as-is + (self.default_index, model.to_string()) + } +} + +#[async_trait] +impl Provider for RouterProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + + let (provider_name, provider) = &self.providers[provider_idx]; + tracing::info!( + provider = provider_name.as_str(), + model = resolved_model.as_str(), + "Router dispatching request" + ); + + provider + .chat_with_system(system_prompt, message, &resolved_model, temperature) + .await + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + let (_, provider) = &self.providers[provider_idx]; + provider + .chat_with_history(messages, &resolved_model, temperature) + .await + } + + async fn chat( + &self, + request: ChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + let (_, provider) = &self.providers[provider_idx]; + provider.chat(request, &resolved_model, temperature).await + } + + fn supports_native_tools(&self) -> bool { + self.providers + .get(self.default_index) + .map(|(_, p)| p.supports_native_tools()) + .unwrap_or(false) + } + + async fn warmup(&self) -> anyhow::Result<()> { + for (name, provider) in &self.providers { + tracing::info!(provider = name, "Warming up routed provider"); + if let Err(e) = provider.warmup().await { + tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}"); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + struct MockProvider { + calls: Arc, + response: &'static str, + last_model: parking_lot::Mutex, + } + + impl MockProvider { + fn new(response: &'static str) -> Self { + Self { + calls: Arc::new(AtomicUsize::new(0)), + response, + last_model: parking_lot::Mutex::new(String::new()), + } + } + + fn call_count(&self) -> usize { + self.calls.load(Ordering::SeqCst) + } + + fn last_model(&self) -> String { + self.last_model.lock().clone() + } + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + *self.last_model.lock() = model.to_string(); + Ok(self.response.to_string()) + } + } + + fn make_router( + providers: Vec<(&'static str, &'static str)>, + routes: Vec<(&str, &str, &str)>, + ) -> (RouterProvider, Vec>) { + let mocks: Vec> = providers + .iter() + .map(|(_, response)| Arc::new(MockProvider::new(response))) + .collect(); + + let provider_list: Vec<(String, Box)> = providers + .iter() + .zip(mocks.iter()) + .map(|((name, _), mock)| { + ( + name.to_string(), + Box::new(Arc::clone(mock)) as Box, + ) + }) + .collect(); + + let route_list: Vec<(String, Route)> = routes + .iter() + .map(|(hint, provider_name, model)| { + ( + hint.to_string(), + Route { + provider_name: provider_name.to_string(), + model: model.to_string(), + }, + ) + }) + .collect(); + + let router = RouterProvider::new(provider_list, route_list, "default-model".to_string()); + + (router, mocks) + } + + // Arc should also be a Provider + #[async_trait] + impl Provider for Arc { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.as_ref() + .chat_with_system(system_prompt, message, model, temperature) + .await + } + } + + #[tokio::test] + async fn routes_hint_to_correct_provider() { + let (router, mocks) = make_router( + vec![("fast", "fast-response"), ("smart", "smart-response")], + vec![ + ("fast", "fast", "llama-3-70b"), + ("reasoning", "smart", "claude-opus"), + ], + ); + + let result = router + .simple_chat("hello", "hint:reasoning", 0.5) + .await + .unwrap(); + assert_eq!(result, "smart-response"); + assert_eq!(mocks[1].call_count(), 1); + assert_eq!(mocks[1].last_model(), "claude-opus"); + assert_eq!(mocks[0].call_count(), 0); + } + + #[tokio::test] + async fn routes_fast_hint() { + let (router, mocks) = make_router( + vec![("fast", "fast-response"), ("smart", "smart-response")], + vec![("fast", "fast", "llama-3-70b")], + ); + + let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap(); + assert_eq!(result, "fast-response"); + assert_eq!(mocks[0].call_count(), 1); + assert_eq!(mocks[0].last_model(), "llama-3-70b"); + } + + #[tokio::test] + async fn unknown_hint_falls_back_to_default() { + let (router, mocks) = make_router( + vec![("default", "default-response"), ("other", "other-response")], + vec![], + ); + + let result = router + .simple_chat("hello", "hint:nonexistent", 0.5) + .await + .unwrap(); + assert_eq!(result, "default-response"); + assert_eq!(mocks[0].call_count(), 1); + // Falls back to default with the hint as model name + assert_eq!(mocks[0].last_model(), "hint:nonexistent"); + } + + #[tokio::test] + async fn non_hint_model_uses_default_provider() { + let (router, mocks) = make_router( + vec![ + ("primary", "primary-response"), + ("secondary", "secondary-response"), + ], + vec![("code", "secondary", "codellama")], + ); + + let result = router + .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) + .await + .unwrap(); + assert_eq!(result, "primary-response"); + assert_eq!(mocks[0].call_count(), 1); + assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514"); + } + + #[test] + fn resolve_preserves_model_for_non_hints() { + let (router, _) = make_router(vec![("default", "ok")], vec![]); + + let (idx, model) = router.resolve("gpt-4o"); + assert_eq!(idx, 0); + assert_eq!(model, "gpt-4o"); + } + + #[test] + fn resolve_strips_hint_prefix() { + let (router, _) = make_router( + vec![("fast", "ok"), ("smart", "ok")], + vec![("reasoning", "smart", "claude-opus")], + ); + + let (idx, model) = router.resolve("hint:reasoning"); + assert_eq!(idx, 1); + assert_eq!(model, "claude-opus"); + } + + #[test] + fn skips_routes_with_unknown_provider() { + let (router, _) = make_router( + vec![("default", "ok")], + vec![("broken", "nonexistent", "model")], + ); + + // Route should not exist + assert!(!router.routes.contains_key("broken")); + } + + #[tokio::test] + async fn warmup_calls_all_providers() { + let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]); + + // Warmup should not error + assert!(router.warmup().await.is_ok()); + } + + #[tokio::test] + async fn chat_with_system_passes_system_prompt() { + let mock = Arc::new(MockProvider::new("response")); + let router = RouterProvider::new( + vec![( + "default".into(), + Box::new(Arc::clone(&mock)) as Box, + )], + vec![], + "model".into(), + ); + + let result = router + .chat_with_system(Some("system"), "hello", "model", 0.5) + .await + .unwrap(); + assert_eq!(result, "response"); + assert_eq!(mock.call_count(), 1); + } +} diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 8a24714..fe830ef 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,12 +1,269 @@ +use crate::tools::ToolSpec; use async_trait::async_trait; +use futures_util::{stream, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; + +/// A single message in a conversation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +impl ChatMessage { + pub fn system(content: impl Into) -> Self { + Self { + role: "system".into(), + content: content.into(), + } + } + + pub fn user(content: impl Into) -> Self { + Self { + role: "user".into(), + content: content.into(), + } + } + + pub fn assistant(content: impl Into) -> Self { + Self { + role: "assistant".into(), + content: content.into(), + } + } + + pub fn tool(content: impl Into) -> Self { + Self { + role: "tool".into(), + content: content.into(), + } + } +} + +/// A tool call requested by the LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: String, +} + +/// An LLM response that may contain text, tool calls, or both. +#[derive(Debug, Clone)] +pub struct ChatResponse { + /// Text content of the response (may be empty if only tool calls). + pub text: Option, + /// Tool calls requested by the LLM. + pub tool_calls: Vec, +} + +impl ChatResponse { + /// True when the LLM wants to invoke at least one tool. + pub fn has_tool_calls(&self) -> bool { + !self.tool_calls.is_empty() + } + + /// Convenience: return text content or empty string. + pub fn text_or_empty(&self) -> &str { + self.text.as_deref().unwrap_or("") + } +} + +/// Request payload for provider chat calls. +#[derive(Debug, Clone, Copy)] +pub struct ChatRequest<'a> { + pub messages: &'a [ChatMessage], + pub tools: Option<&'a [ToolSpec]>, +} + +/// A tool result to feed back to the LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultMessage { + pub tool_call_id: String, + pub content: String, +} + +/// A message in a multi-turn conversation, including tool interactions. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum ConversationMessage { + /// Regular chat message (system, user, assistant). + Chat(ChatMessage), + /// Tool calls from the assistant (stored for history fidelity). + AssistantToolCalls { + text: Option, + tool_calls: Vec, + }, + /// Results of tool executions, fed back to the LLM. + ToolResults(Vec), +} + +/// A chunk of content from a streaming response. +#[derive(Debug, Clone)] +pub struct StreamChunk { + /// Text delta for this chunk. + pub delta: String, + /// Whether this is the final chunk. + pub is_final: bool, + /// Approximate token count for this chunk (estimated). + pub token_count: usize, +} + +impl StreamChunk { + /// Create a new non-final chunk. + pub fn delta(text: impl Into) -> Self { + Self { + delta: text.into(), + is_final: false, + token_count: 0, + } + } + + /// Create a final chunk. + pub fn final_chunk() -> Self { + Self { + delta: String::new(), + is_final: true, + token_count: 0, + } + } + + /// Create an error chunk. + pub fn error(message: impl Into) -> Self { + Self { + delta: message.into(), + is_final: true, + token_count: 0, + } + } + + /// Estimate tokens (rough approximation: ~4 chars per token). + pub fn with_token_estimate(mut self) -> Self { + self.token_count = self.delta.len().div_ceil(4); + self + } +} + +/// Options for streaming chat requests. +#[derive(Debug, Clone, Copy, Default)] +pub struct StreamOptions { + /// Whether to enable streaming (default: true). + pub enabled: bool, + /// Whether to include token counts in chunks. + pub count_tokens: bool, +} + +impl StreamOptions { + /// Create new streaming options with enabled flag. + pub fn new(enabled: bool) -> Self { + Self { + enabled, + count_tokens: false, + } + } + + /// Enable token counting. + pub fn with_token_count(mut self) -> Self { + self.count_tokens = true; + self + } +} + +/// Result type for streaming operations. +pub type StreamResult = std::result::Result; + +/// Errors that can occur during streaming. +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + #[error("HTTP error: {0}")] + Http(reqwest::Error), + + #[error("JSON parse error: {0}")] + Json(serde_json::Error), + + #[error("Invalid SSE format: {0}")] + InvalidSse(String), + + #[error("Provider error: {0}")] + Provider(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +/// Provider capabilities declaration. +/// +/// Describes what features a provider supports, enabling intelligent +/// adaptation of tool calling modes and request formatting. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ProviderCapabilities { + /// Whether the provider supports native tool calling via API primitives. + /// + /// When `true`, the provider can convert tool definitions to API-native + /// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema). + /// + /// When `false`, tools must be injected via system prompt as text. + pub native_tool_calling: bool, +} + +/// Provider-specific tool payload formats. +/// +/// Different LLM providers require different formats for tool definitions. +/// This enum encapsulates those variations, enabling providers to convert +/// from the unified `ToolSpec` format to their native API requirements. +#[derive(Debug, Clone)] +pub enum ToolsPayload { + /// Gemini API format (functionDeclarations). + Gemini { + function_declarations: Vec, + }, + /// Anthropic Messages API format (tools with input_schema). + Anthropic { tools: Vec }, + /// OpenAI Chat Completions API format (tools with function). + OpenAI { tools: Vec }, + /// Prompt-guided fallback (tools injected as text in system prompt). + PromptGuided { instructions: String }, +} #[async_trait] pub trait Provider: Send + Sync { - async fn chat(&self, message: &str, model: &str, temperature: f64) -> anyhow::Result { + /// Query provider capabilities. + /// + /// Default implementation returns minimal capabilities (no native tool calling). + /// Providers should override this to declare their actual capabilities. + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities::default() + } + + /// Convert tool specifications to provider-native format. + /// + /// Default implementation returns `PromptGuided` payload, which injects + /// tool documentation into the system prompt as text. Providers with + /// native tool calling support should override this to return their + /// specific format (Gemini, Anthropic, OpenAI). + fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::PromptGuided { + instructions: build_tool_instructions_text(tools), + } + } + + /// Simple one-shot chat (single user message, no explicit system prompt). + /// + /// This is the preferred API for non-agentic direct interactions. + async fn simple_chat( + &self, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { self.chat_with_system(None, message, model, temperature) .await } + /// One-shot chat with optional system prompt. + /// + /// Kept for compatibility and advanced one-shot prompting. async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -14,4 +271,605 @@ pub trait Provider: Send + Sync { model: &str, temperature: f64, ) -> anyhow::Result; + + /// Multi-turn conversation. Default implementation extracts the last user + /// message and delegates to `chat_with_system`. + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let system = messages + .iter() + .find(|m| m.role == "system") + .map(|m| m.content.as_str()); + let last_user = messages + .iter() + .rfind(|m| m.role == "user") + .map(|m| m.content.as_str()) + .unwrap_or(""); + self.chat_with_system(system, last_user, model, temperature) + .await + } + + /// Structured chat API for agent loop callers. + async fn chat( + &self, + request: ChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + // If tools are provided but provider doesn't support native tools, + // inject tool instructions into system prompt as fallback. + if let Some(tools) = request.tools { + if !tools.is_empty() && !self.supports_native_tools() { + let tool_instructions = match self.convert_tools(tools) { + ToolsPayload::PromptGuided { instructions } => instructions, + payload => { + anyhow::bail!( + "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false" + ) + } + }; + let mut modified_messages = request.messages.to_vec(); + + // Inject tool instructions into an existing system message. + // If none exists, prepend one to the conversation. + if let Some(system_message) = + modified_messages.iter_mut().find(|m| m.role == "system") + { + if !system_message.content.is_empty() { + system_message.content.push_str("\n\n"); + } + system_message.content.push_str(&tool_instructions); + } else { + modified_messages.insert(0, ChatMessage::system(tool_instructions)); + } + + let text = self + .chat_with_history(&modified_messages, model, temperature) + .await?; + return Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }); + } + } + + let text = self + .chat_with_history(request.messages, model, temperature) + .await?; + Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }) + } + + /// Whether provider supports native tool calls over API. + fn supports_native_tools(&self) -> bool { + self.capabilities().native_tool_calling + } + + /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). + /// Default implementation is a no-op; providers with HTTP clients should override. + async fn warmup(&self) -> anyhow::Result<()> { + Ok(()) + } + + /// Chat with tool definitions for native function calling support. + /// The default implementation falls back to chat_with_history and returns + /// an empty tool_calls vector (prompt-based tool use only). + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + _tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let text = self.chat_with_history(messages, model, temperature).await?; + Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }) + } + + /// Whether provider supports streaming responses. + /// Default implementation returns false. + fn supports_streaming(&self) -> bool { + false + } + + /// Streaming chat with optional system prompt. + /// Returns an async stream of text chunks. + /// Default implementation falls back to non-streaming chat. + fn stream_chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + _options: StreamOptions, + ) -> stream::BoxStream<'static, StreamResult> { + // Default: return an empty stream (not supported) + stream::empty().boxed() + } + + /// Streaming chat with history. + /// Default implementation falls back to stream_chat_with_system with last user message. + fn stream_chat_with_history( + &self, + _messages: &[ChatMessage], + _model: &str, + _temperature: f64, + _options: StreamOptions, + ) -> stream::BoxStream<'static, StreamResult> { + // For default implementation, we need to convert to owned strings + // This is a limitation of the default implementation + let provider_name = "unknown".to_string(); + + // Create a single empty chunk to indicate not supported + let chunk = StreamChunk::error(format!("{} does not support streaming", provider_name)); + stream::once(async move { Ok(chunk) }).boxed() + } +} + +/// Build tool instructions text for prompt-guided tool calling. +/// +/// Generates a formatted text block describing available tools and how to +/// invoke them using XML-style tags. This is used as a fallback when the +/// provider doesn't support native tool calling. +pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String { + let mut instructions = String::new(); + + instructions.push_str("## Tool Use Protocol\n\n"); + instructions.push_str("To use a tool, wrap a JSON object in tags:\n\n"); + instructions.push_str("\n"); + instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#); + instructions.push_str("\n\n\n"); + instructions.push_str("You may use multiple tool calls in a single response. "); + instructions.push_str("After tool execution, results appear in tags. "); + instructions + .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools { + writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description) + .expect("writing to String cannot fail"); + + let parameters = + serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string()); + writeln!(&mut instructions, "Parameters: `{parameters}`") + .expect("writing to String cannot fail"); + instructions.push('\n'); + } + + instructions +} + +#[cfg(test)] +mod tests { + use super::*; + + struct CapabilityMockProvider; + + #[async_trait] + impl Provider for CapabilityMockProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + } + } + + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".into()) + } + } + + #[test] + fn chat_message_constructors() { + let sys = ChatMessage::system("Be helpful"); + assert_eq!(sys.role, "system"); + assert_eq!(sys.content, "Be helpful"); + + let user = ChatMessage::user("Hello"); + assert_eq!(user.role, "user"); + + let asst = ChatMessage::assistant("Hi there"); + assert_eq!(asst.role, "assistant"); + + let tool = ChatMessage::tool("{}"); + assert_eq!(tool.role, "tool"); + } + + #[test] + fn chat_response_helpers() { + let empty = ChatResponse { + text: None, + tool_calls: vec![], + }; + assert!(!empty.has_tool_calls()); + assert_eq!(empty.text_or_empty(), ""); + + let with_tools = ChatResponse { + text: Some("Let me check".into()), + tool_calls: vec![ToolCall { + id: "1".into(), + name: "shell".into(), + arguments: "{}".into(), + }], + }; + assert!(with_tools.has_tool_calls()); + assert_eq!(with_tools.text_or_empty(), "Let me check"); + } + + #[test] + fn tool_call_serialization() { + let tc = ToolCall { + id: "call_123".into(), + name: "file_read".into(), + arguments: r#"{"path":"test.txt"}"#.into(), + }; + let json = serde_json::to_string(&tc).unwrap(); + assert!(json.contains("call_123")); + assert!(json.contains("file_read")); + } + + #[test] + fn conversation_message_variants() { + let chat = ConversationMessage::Chat(ChatMessage::user("hi")); + let json = serde_json::to_string(&chat).unwrap(); + assert!(json.contains("\"type\":\"Chat\"")); + + let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage { + tool_call_id: "1".into(), + content: "done".into(), + }]); + let json = serde_json::to_string(&tool_result).unwrap(); + assert!(json.contains("\"type\":\"ToolResults\"")); + } + + #[test] + fn provider_capabilities_default() { + let caps = ProviderCapabilities::default(); + assert!(!caps.native_tool_calling); + } + + #[test] + fn provider_capabilities_equality() { + let caps1 = ProviderCapabilities { + native_tool_calling: true, + }; + let caps2 = ProviderCapabilities { + native_tool_calling: true, + }; + let caps3 = ProviderCapabilities { + native_tool_calling: false, + }; + + assert_eq!(caps1, caps2); + assert_ne!(caps1, caps3); + } + + #[test] + fn supports_native_tools_reflects_capabilities_default_mapping() { + let provider = CapabilityMockProvider; + assert!(provider.supports_native_tools()); + } + + #[test] + fn tools_payload_variants() { + // Test Gemini variant + let gemini = ToolsPayload::Gemini { + function_declarations: vec![serde_json::json!({"name": "test"})], + }; + assert!(matches!(gemini, ToolsPayload::Gemini { .. })); + + // Test Anthropic variant + let anthropic = ToolsPayload::Anthropic { + tools: vec![serde_json::json!({"name": "test"})], + }; + assert!(matches!(anthropic, ToolsPayload::Anthropic { .. })); + + // Test OpenAI variant + let openai = ToolsPayload::OpenAI { + tools: vec![serde_json::json!({"type": "function"})], + }; + assert!(matches!(openai, ToolsPayload::OpenAI { .. })); + + // Test PromptGuided variant + let prompt_guided = ToolsPayload::PromptGuided { + instructions: "Use tools...".to_string(), + }; + assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. })); + } + + #[test] + fn build_tool_instructions_text_format() { + let tools = vec![ + ToolSpec { + name: "shell".to_string(), + description: "Execute commands".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "command": {"type": "string"} + } + }), + }, + ToolSpec { + name: "file_read".to_string(), + description: "Read files".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "path": {"type": "string"} + } + }), + }, + ]; + + let instructions = build_tool_instructions_text(&tools); + + // Check for protocol description + assert!(instructions.contains("Tool Use Protocol")); + assert!(instructions.contains("")); + assert!(instructions.contains("")); + + // Check for tool listings + assert!(instructions.contains("**shell**")); + assert!(instructions.contains("Execute commands")); + assert!(instructions.contains("**file_read**")); + assert!(instructions.contains("Read files")); + + // Check for parameters + assert!(instructions.contains("Parameters:")); + assert!(instructions.contains(r#""type":"object""#)); + } + + #[test] + fn build_tool_instructions_text_empty() { + let instructions = build_tool_instructions_text(&[]); + + // Should still have protocol description + assert!(instructions.contains("Tool Use Protocol")); + + // Should have empty tools section + assert!(instructions.contains("Available Tools")); + } + + // Mock provider for testing. + struct MockProvider { + supports_native: bool, + } + + #[async_trait] + impl Provider for MockProvider { + fn supports_native_tools(&self) -> bool { + self.supports_native + } + + async fn chat_with_system( + &self, + _system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("response".to_string()) + } + } + + #[test] + fn provider_convert_tools_default() { + let provider = MockProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "test_tool".to_string(), + description: "A test tool".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let payload = provider.convert_tools(&tools); + + // Default implementation should return PromptGuided. + assert!(matches!(payload, ToolsPayload::PromptGuided { .. })); + + if let ToolsPayload::PromptGuided { instructions } = payload { + assert!(instructions.contains("test_tool")); + assert!(instructions.contains("A test tool")); + } + } + + #[tokio::test] + async fn provider_chat_prompt_guided_fallback() { + let provider = MockProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + + // Should return a response (default impl calls chat_with_history). + assert!(response.text.is_some()); + } + + #[tokio::test] + async fn provider_chat_without_tools() { + let provider = MockProvider { + supports_native: true, + }; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: None, + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + + // Should work normally without tools. + assert!(response.text.is_some()); + } + + // Provider that echoes the system prompt for assertions. + struct EchoSystemProvider { + supports_native: bool, + } + + #[async_trait] + impl Provider for EchoSystemProvider { + fn supports_native_tools(&self) -> bool { + self.supports_native + } + + async fn chat_with_system( + &self, + system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok(system.unwrap_or_default().to_string()) + } + } + + // Provider with custom prompt-guided conversion. + struct CustomConvertProvider; + + #[async_trait] + impl Provider for CustomConvertProvider { + fn supports_native_tools(&self) -> bool { + false + } + + fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::PromptGuided { + instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(), + } + } + + async fn chat_with_system( + &self, + system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok(system.unwrap_or_default().to_string()) + } + } + + // Provider returning an invalid payload for non-native mode. + struct InvalidConvertProvider; + + #[async_trait] + impl Provider for InvalidConvertProvider { + fn supports_native_tools(&self) -> bool { + false + } + + fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::OpenAI { + tools: vec![serde_json::json!({"type": "function"})], + } + } + + async fn chat_with_system( + &self, + _system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("should_not_reach".to_string()) + } + } + + #[tokio::test] + async fn provider_chat_prompt_guided_preserves_existing_system_not_first() { + let provider = EchoSystemProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ + ChatMessage::user("Hello"), + ChatMessage::system("BASE_SYSTEM_PROMPT"), + ], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + let text = response.text.unwrap_or_default(); + + assert!(text.contains("BASE_SYSTEM_PROMPT")); + assert!(text.contains("Tool Use Protocol")); + } + + #[tokio::test] + async fn provider_chat_prompt_guided_uses_convert_tools_override() { + let provider = CustomConvertProvider; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + let text = response.text.unwrap_or_default(); + + assert!(text.contains("BASE")); + assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS")); + } + + #[tokio::test] + async fn provider_chat_prompt_guided_rejects_non_prompt_payload() { + let provider = InvalidConvertProvider; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let err = provider.chat(request, "model", 0.7).await.unwrap_err(); + let message = err.to_string(); + + assert!(message.contains("non-prompt-guided")); + } } diff --git a/src/rag/mod.rs b/src/rag/mod.rs new file mode 100644 index 0000000..19254f8 --- /dev/null +++ b/src/rag/mod.rs @@ -0,0 +1,395 @@ +//! RAG pipeline for hardware datasheet retrieval. +//! +//! Supports: +//! - Markdown and text datasheets (always) +//! - PDF ingestion (with `rag-pdf` feature) +//! - Pin/alias tables (e.g. `red_led: 13`) for explicit lookup +//! - Keyword retrieval (default) or semantic search via embeddings (optional) + +use crate::memory::chunker; +use std::collections::HashMap; +use std::path::Path; + +/// A chunk of datasheet content with board metadata. +#[derive(Debug, Clone)] +pub struct DatasheetChunk { + /// Board this chunk applies to (e.g. "nucleo-f401re", "rpi-gpio"), or None for generic. + pub board: Option, + /// Source file path (for debugging). + pub source: String, + /// Chunk content. + pub content: String, +} + +/// Pin alias: human-readable name → pin number (e.g. "red_led" → 13). +pub type PinAliases = HashMap; + +/// Parse pin aliases from markdown. Looks for: +/// - `## Pin Aliases` section with `alias: pin` lines +/// - Markdown table `| alias | pin |` +fn parse_pin_aliases(content: &str) -> PinAliases { + let mut aliases = PinAliases::new(); + let content_lower = content.to_lowercase(); + + // Find ## Pin Aliases section + let section_markers = ["## pin aliases", "## pin alias", "## pins"]; + let mut in_section = false; + let mut section_start = 0; + + for marker in section_markers { + if let Some(pos) = content_lower.find(marker) { + in_section = true; + section_start = pos + marker.len(); + break; + } + } + + if !in_section { + return aliases; + } + + let rest = &content[section_start..]; + let section_end = rest + .find("\n## ") + .map(|i| section_start + i) + .unwrap_or(content.len()); + let section = &content[section_start..section_end]; + + // Parse "alias: pin" or "alias = pin" lines + for line in section.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + // Table row: | red_led | 13 | (skip header | alias | pin | and separator |---|) + if line.starts_with('|') { + let parts: Vec<&str> = line.split('|').map(|s| s.trim()).collect(); + if parts.len() >= 3 { + let alias = parts[1].trim().to_lowercase().replace(' ', "_"); + let pin_str = parts[2].trim(); + // Skip header row and separator (|---|) + if alias.eq("alias") + || alias.eq("pin") + || pin_str.eq("pin") + || alias.contains("---") + || pin_str.contains("---") + { + continue; + } + if let Ok(pin) = pin_str.parse::() { + if !alias.is_empty() { + aliases.insert(alias, pin); + } + } + } + continue; + } + // Key: value + if let Some((k, v)) = line.split_once(':').or_else(|| line.split_once('=')) { + let alias = k.trim().to_lowercase().replace(' ', "_"); + if let Ok(pin) = v.trim().parse::() { + if !alias.is_empty() { + aliases.insert(alias, pin); + } + } + } + } + + aliases +} + +fn collect_md_txt_paths(dir: &Path, out: &mut Vec) { + let Ok(entries) = std::fs::read_dir(dir) else { + return; + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + collect_md_txt_paths(&path, out); + } else if path.is_file() { + let ext = path.extension().and_then(|e| e.to_str()); + if ext == Some("md") || ext == Some("txt") { + out.push(path); + } + } + } +} + +#[cfg(feature = "rag-pdf")] +fn collect_pdf_paths(dir: &Path, out: &mut Vec) { + let Ok(entries) = std::fs::read_dir(dir) else { + return; + }; + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + collect_pdf_paths(&path, out); + } else if path.is_file() { + if path.extension().and_then(|e| e.to_str()) == Some("pdf") { + out.push(path); + } + } + } +} + +#[cfg(feature = "rag-pdf")] +fn extract_pdf_text(path: &Path) -> Option { + let bytes = std::fs::read(path).ok()?; + pdf_extract::extract_text_from_mem(&bytes).ok() +} + +/// Hardware RAG index — loads and retrieves datasheet chunks. +pub struct HardwareRag { + chunks: Vec, + /// Per-board pin aliases (board -> alias -> pin). + pin_aliases: HashMap, +} + +impl HardwareRag { + /// Load datasheets from a directory. Expects .md, .txt, and optionally .pdf (with rag-pdf). + /// Filename (without extension) is used as board tag. + /// Supports `## Pin Aliases` section for explicit alias→pin mapping. + pub fn load(workspace_dir: &Path, datasheet_dir: &str) -> anyhow::Result { + let base = workspace_dir.join(datasheet_dir); + if !base.exists() || !base.is_dir() { + return Ok(Self { + chunks: Vec::new(), + pin_aliases: HashMap::new(), + }); + } + + let mut paths: Vec = Vec::new(); + collect_md_txt_paths(&base, &mut paths); + #[cfg(feature = "rag-pdf")] + collect_pdf_paths(&base, &mut paths); + + let mut chunks = Vec::new(); + let mut pin_aliases: HashMap = HashMap::new(); + let max_tokens = 512; + + for path in paths { + let content = if path.extension().and_then(|e| e.to_str()) == Some("pdf") { + #[cfg(feature = "rag-pdf")] + { + extract_pdf_text(&path).unwrap_or_default() + } + #[cfg(not(feature = "rag-pdf"))] + { + String::new() + } + } else { + std::fs::read_to_string(&path).unwrap_or_default() + }; + + if content.trim().is_empty() { + continue; + } + + let board = infer_board_from_path(&path, &base); + let source = path + .strip_prefix(workspace_dir) + .unwrap_or(&path) + .display() + .to_string(); + + // Parse pin aliases from full content + let aliases = parse_pin_aliases(&content); + if let Some(ref b) = board { + if !aliases.is_empty() { + pin_aliases.insert(b.clone(), aliases); + } + } + + for chunk in chunker::chunk_markdown(&content, max_tokens) { + chunks.push(DatasheetChunk { + board: board.clone(), + source: source.clone(), + content: chunk.content, + }); + } + } + + Ok(Self { + chunks, + pin_aliases, + }) + } + + /// Get pin aliases for a board (e.g. "red_led" -> 13). + pub fn pin_aliases_for_board(&self, board: &str) -> Option<&PinAliases> { + self.pin_aliases.get(board) + } + + /// Build pin-alias context for query. When user says "red led", inject "red_led: 13" for matching boards. + pub fn pin_alias_context(&self, query: &str, boards: &[String]) -> String { + let query_lower = query.to_lowercase(); + let query_words: Vec<&str> = query_lower + .split_whitespace() + .filter(|w| w.len() > 1) + .collect(); + + let mut lines = Vec::new(); + for board in boards { + if let Some(aliases) = self.pin_aliases.get(board) { + for (alias, pin) in aliases { + let alias_words: Vec<&str> = alias.split('_').collect(); + let matches = query_words.iter().any(|qw| alias_words.contains(qw)) + || query_lower.contains(&alias.replace('_', " ")); + if matches { + lines.push(format!("{board}: {alias} = pin {pin}")); + } + } + } + } + if lines.is_empty() { + return String::new(); + } + format!("[Pin aliases for query]\n{}\n\n", lines.join("\n")) + } + + /// Retrieve chunks relevant to the query and boards. + /// Uses keyword matching and board filter. Pin-alias context is built separately via `pin_alias_context`. + pub fn retrieve(&self, query: &str, boards: &[String], limit: usize) -> Vec<&DatasheetChunk> { + if self.chunks.is_empty() || limit == 0 { + return Vec::new(); + } + + let query_lower = query.to_lowercase(); + let query_terms: Vec<&str> = query_lower + .split_whitespace() + .filter(|w| w.len() > 2) + .collect(); + + let mut scored: Vec<(&DatasheetChunk, f32)> = Vec::new(); + for chunk in &self.chunks { + let content_lower = chunk.content.to_lowercase(); + let mut score = 0.0f32; + + for term in &query_terms { + if content_lower.contains(term) { + score += 1.0; + } + } + + if score > 0.0 { + let board_match = chunk.board.as_ref().map_or(false, |b| boards.contains(b)); + if board_match { + score += 2.0; + } + scored.push((chunk, score)); + } + } + + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.truncate(limit); + scored.into_iter().map(|(c, _)| c).collect() + } + + /// Number of indexed chunks. + pub fn len(&self) -> usize { + self.chunks.len() + } + + /// True if no chunks are indexed. + pub fn is_empty(&self) -> bool { + self.chunks.is_empty() + } +} + +/// Infer board tag from file path. `nucleo-f401re.md` → Some("nucleo-f401re"). +fn infer_board_from_path(path: &Path, base: &Path) -> Option { + let rel = path.strip_prefix(base).ok()?; + let stem = path.file_stem()?.to_str()?; + + if stem == "generic" || stem.starts_with("generic_") { + return None; + } + if rel.parent().and_then(|p| p.to_str()) == Some("_generic") { + return None; + } + + Some(stem.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_pin_aliases_key_value() { + let md = r#"## Pin Aliases +red_led: 13 +builtin_led: 13 +user_led: 5"#; + let a = parse_pin_aliases(md); + assert_eq!(a.get("red_led"), Some(&13)); + assert_eq!(a.get("builtin_led"), Some(&13)); + assert_eq!(a.get("user_led"), Some(&5)); + } + + #[test] + fn parse_pin_aliases_table() { + let md = r#"## Pin Aliases +| alias | pin | +|-------|-----| +| red_led | 13 | +| builtin_led | 13 |"#; + let a = parse_pin_aliases(md); + assert_eq!(a.get("red_led"), Some(&13)); + assert_eq!(a.get("builtin_led"), Some(&13)); + } + + #[test] + fn parse_pin_aliases_empty() { + let a = parse_pin_aliases("No aliases here"); + assert!(a.is_empty()); + } + + #[test] + fn infer_board_from_path_nucleo() { + let base = std::path::Path::new("/base"); + let path = std::path::Path::new("/base/nucleo-f401re.md"); + assert_eq!( + infer_board_from_path(path, base), + Some("nucleo-f401re".into()) + ); + } + + #[test] + fn infer_board_generic_none() { + let base = std::path::Path::new("/base"); + let path = std::path::Path::new("/base/generic.md"); + assert_eq!(infer_board_from_path(path, base), None); + } + + #[test] + fn hardware_rag_load_and_retrieve() { + let tmp = tempfile::tempdir().unwrap(); + let base = tmp.path().join("datasheets"); + std::fs::create_dir_all(&base).unwrap(); + let content = r#"# Test Board +## Pin Aliases +red_led: 13 +## GPIO +Pin 13: LED +"#; + std::fs::write(base.join("test-board.md"), content).unwrap(); + + let rag = HardwareRag::load(tmp.path(), "datasheets").unwrap(); + assert!(!rag.is_empty()); + let boards = vec!["test-board".to_string()]; + let chunks = rag.retrieve("led", &boards, 5); + assert!(!chunks.is_empty()); + let ctx = rag.pin_alias_context("red led", &boards); + assert!(ctx.contains("13")); + } + + #[test] + fn hardware_rag_load_empty_dir() { + let tmp = tempfile::tempdir().unwrap(); + let base = tmp.path().join("empty_ds"); + std::fs::create_dir_all(&base).unwrap(); + let rag = HardwareRag::load(tmp.path(), "empty_ds").unwrap(); + assert!(rag.is_empty()); + } +} diff --git a/src/runtime/docker.rs b/src/runtime/docker.rs new file mode 100644 index 0000000..eaa3d09 --- /dev/null +++ b/src/runtime/docker.rs @@ -0,0 +1,199 @@ +use super::traits::RuntimeAdapter; +use crate::config::DockerRuntimeConfig; +use anyhow::{Context, Result}; +use std::path::{Path, PathBuf}; + +/// Docker runtime with lightweight container isolation. +#[derive(Debug, Clone)] +pub struct DockerRuntime { + config: DockerRuntimeConfig, +} + +impl DockerRuntime { + pub fn new(config: DockerRuntimeConfig) -> Self { + Self { config } + } + + fn workspace_mount_path(&self, workspace_dir: &Path) -> Result { + let resolved = workspace_dir + .canonicalize() + .unwrap_or_else(|_| workspace_dir.to_path_buf()); + + if !resolved.is_absolute() { + anyhow::bail!( + "Docker runtime requires an absolute workspace path, got: {}", + resolved.display() + ); + } + + if resolved == Path::new("/") { + anyhow::bail!("Refusing to mount filesystem root (/) into docker runtime"); + } + + if self.config.allowed_workspace_roots.is_empty() { + return Ok(resolved); + } + + let allowed = self.config.allowed_workspace_roots.iter().any(|root| { + let root_path = Path::new(root) + .canonicalize() + .unwrap_or_else(|_| PathBuf::from(root)); + resolved.starts_with(root_path) + }); + + if !allowed { + anyhow::bail!( + "Workspace path {} is not in runtime.docker.allowed_workspace_roots", + resolved.display() + ); + } + + Ok(resolved) + } +} + +impl RuntimeAdapter for DockerRuntime { + fn name(&self) -> &str { + "docker" + } + + fn has_shell_access(&self) -> bool { + true + } + + fn has_filesystem_access(&self) -> bool { + self.config.mount_workspace + } + + fn storage_path(&self) -> PathBuf { + if self.config.mount_workspace { + PathBuf::from("/workspace/.zeroclaw") + } else { + PathBuf::from("/tmp/.zeroclaw") + } + } + + fn supports_long_running(&self) -> bool { + false + } + + fn memory_budget(&self) -> u64 { + self.config + .memory_limit_mb + .map_or(0, |mb| mb.saturating_mul(1024 * 1024)) + } + + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut process = tokio::process::Command::new("docker"); + process + .arg("run") + .arg("--rm") + .arg("--init") + .arg("--interactive"); + + let network = self.config.network.trim(); + if !network.is_empty() { + process.arg("--network").arg(network); + } + + if let Some(memory_limit_mb) = self.config.memory_limit_mb.filter(|mb| *mb > 0) { + process.arg("--memory").arg(format!("{memory_limit_mb}m")); + } + + if let Some(cpu_limit) = self.config.cpu_limit.filter(|cpus| *cpus > 0.0) { + process.arg("--cpus").arg(cpu_limit.to_string()); + } + + if self.config.read_only_rootfs { + process.arg("--read-only"); + } + + if self.config.mount_workspace { + let host_workspace = self.workspace_mount_path(workspace_dir).with_context(|| { + format!( + "Failed to validate workspace mount path {}", + workspace_dir.display() + ) + })?; + + process + .arg("--volume") + .arg(format!("{}:/workspace:rw", host_workspace.display())) + .arg("--workdir") + .arg("/workspace"); + } + + process + .arg(self.config.image.trim()) + .arg("sh") + .arg("-c") + .arg(command); + + Ok(process) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn docker_runtime_name() { + let runtime = DockerRuntime::new(DockerRuntimeConfig::default()); + assert_eq!(runtime.name(), "docker"); + } + + #[test] + fn docker_runtime_memory_budget() { + let mut cfg = DockerRuntimeConfig::default(); + cfg.memory_limit_mb = Some(256); + let runtime = DockerRuntime::new(cfg); + assert_eq!(runtime.memory_budget(), 256 * 1024 * 1024); + } + + #[test] + fn docker_build_shell_command_includes_runtime_flags() { + let cfg = DockerRuntimeConfig { + image: "alpine:3.20".into(), + network: "none".into(), + memory_limit_mb: Some(128), + cpu_limit: Some(1.5), + read_only_rootfs: true, + mount_workspace: true, + allowed_workspace_roots: Vec::new(), + }; + let runtime = DockerRuntime::new(cfg); + + let workspace = std::env::temp_dir(); + let command = runtime + .build_shell_command("echo hello", &workspace) + .unwrap(); + let debug = format!("{command:?}"); + + assert!(debug.contains("docker")); + assert!(debug.contains("--memory")); + assert!(debug.contains("128m")); + assert!(debug.contains("--cpus")); + assert!(debug.contains("1.5")); + assert!(debug.contains("--workdir")); + assert!(debug.contains("echo hello")); + } + + #[test] + fn docker_workspace_allowlist_blocks_outside_paths() { + let cfg = DockerRuntimeConfig { + allowed_workspace_roots: vec!["/tmp/allowed".into()], + ..DockerRuntimeConfig::default() + }; + let runtime = DockerRuntime::new(cfg); + + let outside = PathBuf::from("/tmp/blocked_workspace"); + let result = runtime.build_shell_command("echo test", &outside); + + assert!(result.is_err()); + } +} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 9ed0ee0..cea7aa3 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -1,6 +1,8 @@ +pub mod docker; pub mod native; pub mod traits; +pub use docker::DockerRuntime; pub use native::NativeRuntime; pub use traits::RuntimeAdapter; @@ -10,18 +12,14 @@ use crate::config::RuntimeConfig; pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result> { match config.kind.as_str() { "native" => Ok(Box::new(NativeRuntime::new())), - "docker" => anyhow::bail!( - "runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands." - ), + "docker" => Ok(Box::new(DockerRuntime::new(config.docker.clone()))), "cloudflare" => anyhow::bail!( "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now." ), - other if other.trim().is_empty() => anyhow::bail!( - "runtime.kind cannot be empty. Supported values: native" - ), - other => anyhow::bail!( - "Unknown runtime kind '{other}'. Supported values: native" - ), + other if other.trim().is_empty() => { + anyhow::bail!("runtime.kind cannot be empty. Supported values: native, docker") + } + other => anyhow::bail!("Unknown runtime kind '{other}'. Supported values: native, docker"), } } @@ -33,6 +31,7 @@ mod tests { fn factory_native() { let cfg = RuntimeConfig { kind: "native".into(), + ..RuntimeConfig::default() }; let rt = create_runtime(&cfg).unwrap(); assert_eq!(rt.name(), "native"); @@ -40,20 +39,21 @@ mod tests { } #[test] - fn factory_docker_errors() { + fn factory_docker() { let cfg = RuntimeConfig { kind: "docker".into(), + ..RuntimeConfig::default() }; - match create_runtime(&cfg) { - Err(err) => assert!(err.to_string().contains("not implemented")), - Ok(_) => panic!("docker runtime should error"), - } + let rt = create_runtime(&cfg).unwrap(); + assert_eq!(rt.name(), "docker"); + assert!(rt.has_shell_access()); } #[test] fn factory_cloudflare_errors() { let cfg = RuntimeConfig { kind: "cloudflare".into(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("not implemented")), @@ -65,6 +65,7 @@ mod tests { fn factory_unknown_errors() { let cfg = RuntimeConfig { kind: "wasm-edge-unknown".into(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("Unknown runtime kind")), @@ -76,6 +77,7 @@ mod tests { fn factory_empty_errors() { let cfg = RuntimeConfig { kind: String::new(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("cannot be empty")), diff --git a/src/runtime/native.rs b/src/runtime/native.rs index 4b0ef3c..927c895 100644 --- a/src/runtime/native.rs +++ b/src/runtime/native.rs @@ -1,5 +1,5 @@ use super::traits::RuntimeAdapter; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; /// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi pub struct NativeRuntime; @@ -33,6 +33,16 @@ impl RuntimeAdapter for NativeRuntime { fn supports_long_running(&self) -> bool { true } + + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut process = tokio::process::Command::new("sh"); + process.arg("-c").arg(command).current_dir(workspace_dir); + Ok(process) + } } #[cfg(test)] @@ -69,4 +79,14 @@ mod tests { let path = NativeRuntime::new().storage_path(); assert!(path.to_string_lossy().contains("zeroclaw")); } + + #[test] + fn native_builds_shell_command() { + let cwd = std::env::temp_dir(); + let command = NativeRuntime::new() + .build_shell_command("echo hello", &cwd) + .unwrap(); + let debug = format!("{command:?}"); + assert!(debug.contains("echo hello")); + } } diff --git a/src/runtime/traits.rs b/src/runtime/traits.rs index cbff5b1..153c06f 100644 --- a/src/runtime/traits.rs +++ b/src/runtime/traits.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; /// Runtime adapter — abstracts platform differences so the same agent /// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc. @@ -22,4 +22,82 @@ pub trait RuntimeAdapter: Send + Sync { fn memory_budget(&self) -> u64 { 0 } + + /// Build a shell command process for this runtime. + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result; +} + +#[cfg(test)] +mod tests { + use super::*; + + struct DummyRuntime; + + impl RuntimeAdapter for DummyRuntime { + fn name(&self) -> &str { + "dummy-runtime" + } + + fn has_shell_access(&self) -> bool { + true + } + + fn has_filesystem_access(&self) -> bool { + true + } + + fn storage_path(&self) -> PathBuf { + PathBuf::from("/tmp/dummy-runtime") + } + + fn supports_long_running(&self) -> bool { + true + } + + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut cmd = tokio::process::Command::new("echo"); + cmd.arg(command); + cmd.current_dir(workspace_dir); + Ok(cmd) + } + } + + #[test] + fn default_memory_budget_is_zero() { + let runtime = DummyRuntime; + assert_eq!(runtime.memory_budget(), 0); + } + + #[test] + fn runtime_reports_capabilities() { + let runtime = DummyRuntime; + + assert_eq!(runtime.name(), "dummy-runtime"); + assert!(runtime.has_shell_access()); + assert!(runtime.has_filesystem_access()); + assert!(runtime.supports_long_running()); + assert_eq!(runtime.storage_path(), PathBuf::from("/tmp/dummy-runtime")); + } + + #[tokio::test] + async fn build_shell_command_executes() { + let runtime = DummyRuntime; + let mut cmd = runtime + .build_shell_command("hello-runtime", Path::new(".")) + .unwrap(); + + let output = cmd.output().await.unwrap(); + let stdout = String::from_utf8_lossy(&output.stdout); + + assert!(output.status.success()); + assert!(stdout.contains("hello-runtime")); + } } diff --git a/src/runtime/wasm.rs b/src/runtime/wasm.rs new file mode 100644 index 0000000..6b4c6f3 --- /dev/null +++ b/src/runtime/wasm.rs @@ -0,0 +1,620 @@ +//! WASM sandbox runtime — in-process tool isolation via `wasmi`. +//! +//! Provides capability-based sandboxing without Docker or external runtimes. +//! Each WASM module runs with: +//! - **Fuel limits**: prevents infinite loops (each instruction costs 1 fuel) +//! - **Memory caps**: configurable per-module memory ceiling +//! - **No filesystem access**: by default, tools are pure computation +//! - **No network access**: unless explicitly allowlisted hosts are configured +//! +//! # Feature gate +//! This module is only compiled when `--features runtime-wasm` is enabled. +//! The default ZeroClaw binary excludes it to maintain the 4.6 MB size target. + +use super::traits::RuntimeAdapter; +use crate::config::WasmRuntimeConfig; +use anyhow::{bail, Context, Result}; +use std::path::{Path, PathBuf}; + +/// WASM sandbox runtime — executes tool modules in an isolated interpreter. +#[derive(Debug, Clone)] +pub struct WasmRuntime { + config: WasmRuntimeConfig, + workspace_dir: Option, +} + +/// Result of executing a WASM module. +#[derive(Debug, Clone)] +pub struct WasmExecutionResult { + /// Standard output captured from the module (if WASI is used) + pub stdout: String, + /// Standard error captured from the module + pub stderr: String, + /// Exit code (0 = success) + pub exit_code: i32, + /// Fuel consumed during execution + pub fuel_consumed: u64, +} + +/// Capabilities granted to a WASM tool module. +#[derive(Debug, Clone, Default)] +pub struct WasmCapabilities { + /// Allow reading files from workspace + pub read_workspace: bool, + /// Allow writing files to workspace + pub write_workspace: bool, + /// Allowed HTTP hosts (empty = no network) + pub allowed_hosts: Vec, + /// Custom fuel override (0 = use config default) + pub fuel_override: u64, + /// Custom memory override in MB (0 = use config default) + pub memory_override_mb: u64, +} + +impl WasmRuntime { + /// Create a new WASM runtime with the given configuration. + pub fn new(config: WasmRuntimeConfig) -> Self { + Self { + config, + workspace_dir: None, + } + } + + /// Create a WASM runtime bound to a specific workspace directory. + pub fn with_workspace(config: WasmRuntimeConfig, workspace_dir: PathBuf) -> Self { + Self { + config, + workspace_dir: Some(workspace_dir), + } + } + + /// Check if the WASM runtime feature is available in this build. + pub fn is_available() -> bool { + cfg!(feature = "runtime-wasm") + } + + /// Validate the WASM config for common misconfigurations. + pub fn validate_config(&self) -> Result<()> { + if self.config.memory_limit_mb == 0 { + bail!("runtime.wasm.memory_limit_mb must be > 0"); + } + if self.config.memory_limit_mb > 4096 { + bail!( + "runtime.wasm.memory_limit_mb of {} exceeds the 4 GB safety limit for 32-bit WASM", + self.config.memory_limit_mb + ); + } + if self.config.tools_dir.is_empty() { + bail!("runtime.wasm.tools_dir cannot be empty"); + } + // Verify tools directory doesn't escape workspace + if self.config.tools_dir.contains("..") { + bail!("runtime.wasm.tools_dir must not contain '..' path traversal"); + } + Ok(()) + } + + /// Resolve the absolute path to the WASM tools directory. + pub fn tools_dir(&self, workspace_dir: &Path) -> PathBuf { + workspace_dir.join(&self.config.tools_dir) + } + + /// Build capabilities from config defaults. + pub fn default_capabilities(&self) -> WasmCapabilities { + WasmCapabilities { + read_workspace: self.config.allow_workspace_read, + write_workspace: self.config.allow_workspace_write, + allowed_hosts: self.config.allowed_hosts.clone(), + fuel_override: 0, + memory_override_mb: 0, + } + } + + /// Get the effective fuel limit for an invocation. + pub fn effective_fuel(&self, caps: &WasmCapabilities) -> u64 { + if caps.fuel_override > 0 { + caps.fuel_override + } else { + self.config.fuel_limit + } + } + + /// Get the effective memory limit in bytes. + pub fn effective_memory_bytes(&self, caps: &WasmCapabilities) -> u64 { + let mb = if caps.memory_override_mb > 0 { + caps.memory_override_mb + } else { + self.config.memory_limit_mb + }; + mb.saturating_mul(1024 * 1024) + } + + /// Execute a WASM module from the tools directory. + /// + /// This is the primary entry point for running sandboxed tool code. + /// The module must export a `_start` function (WASI convention) or + /// a custom `run` function that takes no arguments and returns i32. + #[cfg(feature = "runtime-wasm")] + pub fn execute_module( + &self, + module_name: &str, + workspace_dir: &Path, + caps: &WasmCapabilities, + ) -> Result { + use wasmi::{Engine, Linker, Module, Store}; + + // Resolve module path + let tools_path = self.tools_dir(workspace_dir); + let module_path = tools_path.join(format!("{module_name}.wasm")); + + if !module_path.exists() { + bail!( + "WASM module not found: {} (looked in {})", + module_name, + tools_path.display() + ); + } + + // Read module bytes + let wasm_bytes = std::fs::read(&module_path) + .with_context(|| format!("Failed to read WASM module: {}", module_path.display()))?; + + // Validate module size (sanity check) + if wasm_bytes.len() > 50 * 1024 * 1024 { + bail!( + "WASM module {} is {} MB — exceeds 50 MB safety limit", + module_name, + wasm_bytes.len() / (1024 * 1024) + ); + } + + // Configure engine with fuel metering + let mut engine_config = wasmi::Config::default(); + engine_config.consume_fuel(true); + let engine = Engine::new(&engine_config); + + // Parse and validate module + let module = Module::new(&engine, &wasm_bytes[..]) + .with_context(|| format!("Failed to parse WASM module: {module_name}"))?; + + // Create store with fuel budget + let mut store = Store::new(&engine, ()); + let fuel = self.effective_fuel(caps); + if fuel > 0 { + store.set_fuel(fuel).with_context(|| { + format!("Failed to set fuel budget ({fuel}) for module: {module_name}") + })?; + } + + // Link host functions (minimal — pure sandboxing) + let linker = Linker::new(&engine); + + // Instantiate module + let instance = linker + .instantiate(&mut store, &module) + .and_then(|pre| pre.start(&mut store)) + .with_context(|| format!("Failed to instantiate WASM module: {module_name}"))?; + + // Look for exported entry point + let run_fn = instance + .get_typed_func::<(), i32>(&store, "run") + .or_else(|_| instance.get_typed_func::<(), i32>(&store, "_start")) + .with_context(|| { + format!( + "WASM module '{module_name}' must export a 'run() -> i32' or '_start() -> i32' function" + ) + })?; + + // Execute with fuel accounting + let fuel_before = store.get_fuel().unwrap_or(0); + let exit_code = match run_fn.call(&mut store, ()) { + Ok(code) => code, + Err(e) => { + // Check if we ran out of fuel (infinite loop protection) + let fuel_after = store.get_fuel().unwrap_or(0); + if fuel_after == 0 && fuel > 0 { + return Ok(WasmExecutionResult { + stdout: String::new(), + stderr: format!( + "WASM module '{module_name}' exceeded fuel limit ({fuel} ticks) — likely an infinite loop" + ), + exit_code: -1, + fuel_consumed: fuel, + }); + } + bail!("WASM execution error in '{module_name}': {e}"); + } + }; + let fuel_after = store.get_fuel().unwrap_or(0); + let fuel_consumed = fuel_before.saturating_sub(fuel_after); + + Ok(WasmExecutionResult { + stdout: String::new(), // No WASI stdout yet — pure computation + stderr: String::new(), + exit_code, + fuel_consumed, + }) + } + + /// Stub for when the `runtime-wasm` feature is not enabled. + #[cfg(not(feature = "runtime-wasm"))] + pub fn execute_module( + &self, + module_name: &str, + _workspace_dir: &Path, + _caps: &WasmCapabilities, + ) -> Result { + bail!( + "WASM runtime is not available in this build. \ + Rebuild with `cargo build --features runtime-wasm` to enable WASM sandbox support. \ + Module requested: {module_name}" + ) + } + + /// List available WASM tool modules in the tools directory. + pub fn list_modules(&self, workspace_dir: &Path) -> Result> { + let tools_path = self.tools_dir(workspace_dir); + if !tools_path.exists() { + return Ok(Vec::new()); + } + + let mut modules = Vec::new(); + for entry in std::fs::read_dir(&tools_path) + .with_context(|| format!("Failed to read tools dir: {}", tools_path.display()))? + { + let entry = entry?; + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "wasm") { + if let Some(stem) = path.file_stem() { + modules.push(stem.to_string_lossy().to_string()); + } + } + } + modules.sort(); + Ok(modules) + } +} + +impl RuntimeAdapter for WasmRuntime { + fn name(&self) -> &str { + "wasm" + } + + fn has_shell_access(&self) -> bool { + // WASM sandbox does NOT provide shell access — that's the point + false + } + + fn has_filesystem_access(&self) -> bool { + self.config.allow_workspace_read || self.config.allow_workspace_write + } + + fn storage_path(&self) -> PathBuf { + self.workspace_dir + .as_ref() + .map_or_else(|| PathBuf::from(".zeroclaw"), |w| w.join(".zeroclaw")) + } + + fn supports_long_running(&self) -> bool { + // WASM modules are short-lived invocations, not daemons + false + } + + fn memory_budget(&self) -> u64 { + self.config.memory_limit_mb.saturating_mul(1024 * 1024) + } + + fn build_shell_command( + &self, + _command: &str, + _workspace_dir: &Path, + ) -> anyhow::Result { + bail!( + "WASM runtime does not support shell commands. \ + Use `execute_module()` to run WASM tools, or switch to runtime.kind = \"native\" for shell access." + ) + } +} + +// ── Tests ─────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn default_config() -> WasmRuntimeConfig { + WasmRuntimeConfig::default() + } + + // ── Basic trait compliance ────────────────────────────────── + + #[test] + fn wasm_runtime_name() { + let rt = WasmRuntime::new(default_config()); + assert_eq!(rt.name(), "wasm"); + } + + #[test] + fn wasm_no_shell_access() { + let rt = WasmRuntime::new(default_config()); + assert!(!rt.has_shell_access()); + } + + #[test] + fn wasm_no_filesystem_by_default() { + let rt = WasmRuntime::new(default_config()); + assert!(!rt.has_filesystem_access()); + } + + #[test] + fn wasm_filesystem_when_read_enabled() { + let mut cfg = default_config(); + cfg.allow_workspace_read = true; + let rt = WasmRuntime::new(cfg); + assert!(rt.has_filesystem_access()); + } + + #[test] + fn wasm_filesystem_when_write_enabled() { + let mut cfg = default_config(); + cfg.allow_workspace_write = true; + let rt = WasmRuntime::new(cfg); + assert!(rt.has_filesystem_access()); + } + + #[test] + fn wasm_no_long_running() { + let rt = WasmRuntime::new(default_config()); + assert!(!rt.supports_long_running()); + } + + #[test] + fn wasm_memory_budget() { + let rt = WasmRuntime::new(default_config()); + assert_eq!(rt.memory_budget(), 64 * 1024 * 1024); + } + + #[test] + fn wasm_shell_command_errors() { + let rt = WasmRuntime::new(default_config()); + let result = rt.build_shell_command("echo hello", Path::new("/tmp")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("does not support shell")); + } + + #[test] + fn wasm_storage_path_default() { + let rt = WasmRuntime::new(default_config()); + assert!(rt.storage_path().to_string_lossy().contains("zeroclaw")); + } + + #[test] + fn wasm_storage_path_with_workspace() { + let rt = WasmRuntime::with_workspace(default_config(), PathBuf::from("/home/user/project")); + assert_eq!(rt.storage_path(), PathBuf::from("/home/user/project/.zeroclaw")); + } + + // ── Config validation ────────────────────────────────────── + + #[test] + fn validate_rejects_zero_memory() { + let mut cfg = default_config(); + cfg.memory_limit_mb = 0; + let rt = WasmRuntime::new(cfg); + let err = rt.validate_config().unwrap_err(); + assert!(err.to_string().contains("must be > 0")); + } + + #[test] + fn validate_rejects_excessive_memory() { + let mut cfg = default_config(); + cfg.memory_limit_mb = 8192; + let rt = WasmRuntime::new(cfg); + let err = rt.validate_config().unwrap_err(); + assert!(err.to_string().contains("4 GB safety limit")); + } + + #[test] + fn validate_rejects_empty_tools_dir() { + let mut cfg = default_config(); + cfg.tools_dir = String::new(); + let rt = WasmRuntime::new(cfg); + let err = rt.validate_config().unwrap_err(); + assert!(err.to_string().contains("cannot be empty")); + } + + #[test] + fn validate_rejects_path_traversal() { + let mut cfg = default_config(); + cfg.tools_dir = "../../../etc/passwd".into(); + let rt = WasmRuntime::new(cfg); + let err = rt.validate_config().unwrap_err(); + assert!(err.to_string().contains("path traversal")); + } + + #[test] + fn validate_accepts_valid_config() { + let rt = WasmRuntime::new(default_config()); + assert!(rt.validate_config().is_ok()); + } + + #[test] + fn validate_accepts_max_memory() { + let mut cfg = default_config(); + cfg.memory_limit_mb = 4096; + let rt = WasmRuntime::new(cfg); + assert!(rt.validate_config().is_ok()); + } + + // ── Capabilities & fuel ──────────────────────────────────── + + #[test] + fn effective_fuel_uses_config_default() { + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities::default(); + assert_eq!(rt.effective_fuel(&caps), 1_000_000); + } + + #[test] + fn effective_fuel_respects_override() { + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities { + fuel_override: 500, + ..Default::default() + }; + assert_eq!(rt.effective_fuel(&caps), 500); + } + + #[test] + fn effective_memory_uses_config_default() { + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities::default(); + assert_eq!(rt.effective_memory_bytes(&caps), 64 * 1024 * 1024); + } + + #[test] + fn effective_memory_respects_override() { + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities { + memory_override_mb: 128, + ..Default::default() + }; + assert_eq!(rt.effective_memory_bytes(&caps), 128 * 1024 * 1024); + } + + #[test] + fn default_capabilities_match_config() { + let mut cfg = default_config(); + cfg.allow_workspace_read = true; + cfg.allowed_hosts = vec!["api.example.com".into()]; + let rt = WasmRuntime::new(cfg); + let caps = rt.default_capabilities(); + assert!(caps.read_workspace); + assert!(!caps.write_workspace); + assert_eq!(caps.allowed_hosts, vec!["api.example.com"]); + } + + // ── Tools directory ──────────────────────────────────────── + + #[test] + fn tools_dir_resolves_relative_to_workspace() { + let rt = WasmRuntime::new(default_config()); + let dir = rt.tools_dir(Path::new("/home/user/project")); + assert_eq!(dir, PathBuf::from("/home/user/project/tools/wasm")); + } + + #[test] + fn list_modules_empty_when_dir_missing() { + let rt = WasmRuntime::new(default_config()); + let modules = rt.list_modules(Path::new("/nonexistent/path")).unwrap(); + assert!(modules.is_empty()); + } + + #[test] + fn list_modules_finds_wasm_files() { + let dir = tempfile::tempdir().unwrap(); + let tools_dir = dir.path().join("tools/wasm"); + std::fs::create_dir_all(&tools_dir).unwrap(); + + // Create dummy .wasm files + std::fs::write(tools_dir.join("calculator.wasm"), b"\0asm").unwrap(); + std::fs::write(tools_dir.join("formatter.wasm"), b"\0asm").unwrap(); + std::fs::write(tools_dir.join("readme.txt"), b"not a wasm").unwrap(); + + let rt = WasmRuntime::new(default_config()); + let modules = rt.list_modules(dir.path()).unwrap(); + assert_eq!(modules, vec!["calculator", "formatter"]); + } + + // ── Module execution edge cases ──────────────────────────── + + #[test] + fn execute_module_missing_file() { + let dir = tempfile::tempdir().unwrap(); + let tools_dir = dir.path().join("tools/wasm"); + std::fs::create_dir_all(&tools_dir).unwrap(); + + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities::default(); + let result = rt.execute_module("nonexistent", dir.path(), &caps); + assert!(result.is_err()); + + let err_msg = result.unwrap_err().to_string(); + // Should mention the module name + assert!(err_msg.contains("nonexistent")); + } + + #[test] + fn execute_module_invalid_wasm() { + let dir = tempfile::tempdir().unwrap(); + let tools_dir = dir.path().join("tools/wasm"); + std::fs::create_dir_all(&tools_dir).unwrap(); + + // Write invalid WASM bytes + std::fs::write(tools_dir.join("bad.wasm"), b"not valid wasm bytes at all").unwrap(); + + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities::default(); + let result = rt.execute_module("bad", dir.path(), &caps); + assert!(result.is_err()); + } + + #[test] + fn execute_module_oversized_file() { + let dir = tempfile::tempdir().unwrap(); + let tools_dir = dir.path().join("tools/wasm"); + std::fs::create_dir_all(&tools_dir).unwrap(); + + // Write a file > 50 MB (we just check the size, don't actually allocate) + // This test verifies the check without consuming 50 MB of disk + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities::default(); + + // File doesn't exist for oversized test — the missing file check catches first + // But if it did exist and was 51 MB, the size check would catch it + let result = rt.execute_module("oversized", dir.path(), &caps); + assert!(result.is_err()); + } + + // ── Feature gate check ───────────────────────────────────── + + #[test] + fn is_available_matches_feature_flag() { + // This test verifies the compile-time feature detection works + let available = WasmRuntime::is_available(); + assert_eq!(available, cfg!(feature = "runtime-wasm")); + } + + // ── Memory overflow edge cases ───────────────────────────── + + #[test] + fn memory_budget_no_overflow() { + let mut cfg = default_config(); + cfg.memory_limit_mb = 4096; // Max valid + let rt = WasmRuntime::new(cfg); + assert_eq!(rt.memory_budget(), 4096 * 1024 * 1024); + } + + #[test] + fn effective_memory_saturating() { + let rt = WasmRuntime::new(default_config()); + let caps = WasmCapabilities { + memory_override_mb: u64::MAX, + ..Default::default() + }; + // Should not panic — saturating_mul prevents overflow + let _bytes = rt.effective_memory_bytes(&caps); + } + + // ── WasmCapabilities default ─────────────────────────────── + + #[test] + fn capabilities_default_is_locked_down() { + let caps = WasmCapabilities::default(); + assert!(!caps.read_workspace); + assert!(!caps.write_workspace); + assert!(caps.allowed_hosts.is_empty()); + assert_eq!(caps.fuel_override, 0); + assert_eq!(caps.memory_override_mb, 0); + } +} diff --git a/src/security/audit.rs b/src/security/audit.rs new file mode 100644 index 0000000..5eb2b42 --- /dev/null +++ b/src/security/audit.rs @@ -0,0 +1,335 @@ +//! Audit logging for security events + +use crate::config::AuditConfig; +use anyhow::Result; +use chrono::{DateTime, Utc}; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use std::fs::OpenOptions; +use std::io::Write; +use std::path::PathBuf; +use uuid::Uuid; + +/// Audit event types +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AuditEventType { + CommandExecution, + FileAccess, + ConfigChange, + AuthSuccess, + AuthFailure, + PolicyViolation, + SecurityEvent, +} + +/// Actor information (who performed the action) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Actor { + pub channel: String, + pub user_id: Option, + pub username: Option, +} + +/// Action information (what was done) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Action { + pub command: Option, + pub risk_level: Option, + pub approved: bool, + pub allowed: bool, +} + +/// Execution result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExecutionResult { + pub success: bool, + pub exit_code: Option, + pub duration_ms: Option, + pub error: Option, +} + +/// Security context +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SecurityContext { + pub policy_violation: bool, + pub rate_limit_remaining: Option, + pub sandbox_backend: Option, +} + +/// Complete audit event +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuditEvent { + pub timestamp: DateTime, + pub event_id: String, + pub event_type: AuditEventType, + pub actor: Option, + pub action: Option, + pub result: Option, + pub security: SecurityContext, +} + +impl AuditEvent { + /// Create a new audit event + pub fn new(event_type: AuditEventType) -> Self { + Self { + timestamp: Utc::now(), + event_id: Uuid::new_v4().to_string(), + event_type, + actor: None, + action: None, + result: None, + security: SecurityContext { + policy_violation: false, + rate_limit_remaining: None, + sandbox_backend: None, + }, + } + } + + /// Set the actor + pub fn with_actor( + mut self, + channel: String, + user_id: Option, + username: Option, + ) -> Self { + self.actor = Some(Actor { + channel, + user_id, + username, + }); + self + } + + /// Set the action + pub fn with_action( + mut self, + command: String, + risk_level: String, + approved: bool, + allowed: bool, + ) -> Self { + self.action = Some(Action { + command: Some(command), + risk_level: Some(risk_level), + approved, + allowed, + }); + self + } + + /// Set the result + pub fn with_result( + mut self, + success: bool, + exit_code: Option, + duration_ms: u64, + error: Option, + ) -> Self { + self.result = Some(ExecutionResult { + success, + exit_code, + duration_ms: Some(duration_ms), + error, + }); + self + } + + /// Set security context + pub fn with_security(mut self, sandbox_backend: Option) -> Self { + self.security.sandbox_backend = sandbox_backend; + self + } +} + +/// Audit logger +pub struct AuditLogger { + log_path: PathBuf, + config: AuditConfig, + buffer: Mutex>, +} + +/// Structured command execution details for audit logging. +#[derive(Debug, Clone)] +pub struct CommandExecutionLog<'a> { + pub channel: &'a str, + pub command: &'a str, + pub risk_level: &'a str, + pub approved: bool, + pub allowed: bool, + pub success: bool, + pub duration_ms: u64, +} + +impl AuditLogger { + /// Create a new audit logger + pub fn new(config: AuditConfig, zeroclaw_dir: PathBuf) -> Result { + let log_path = zeroclaw_dir.join(&config.log_path); + Ok(Self { + log_path, + config, + buffer: Mutex::new(Vec::new()), + }) + } + + /// Log an event + pub fn log(&self, event: &AuditEvent) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + // Check log size and rotate if needed + self.rotate_if_needed()?; + + // Serialize and write + let line = serde_json::to_string(event)?; + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.log_path)?; + + writeln!(file, "{}", line)?; + file.sync_all()?; + + Ok(()) + } + + /// Log a command execution event. + pub fn log_command_event(&self, entry: CommandExecutionLog<'_>) -> Result<()> { + let event = AuditEvent::new(AuditEventType::CommandExecution) + .with_actor(entry.channel.to_string(), None, None) + .with_action( + entry.command.to_string(), + entry.risk_level.to_string(), + entry.approved, + entry.allowed, + ) + .with_result(entry.success, None, entry.duration_ms, None); + + self.log(&event) + } + + /// Backward-compatible helper to log a command execution event. + #[allow(clippy::too_many_arguments)] + pub fn log_command( + &self, + channel: &str, + command: &str, + risk_level: &str, + approved: bool, + allowed: bool, + success: bool, + duration_ms: u64, + ) -> Result<()> { + self.log_command_event(CommandExecutionLog { + channel, + command, + risk_level, + approved, + allowed, + success, + duration_ms, + }) + } + + /// Rotate log if it exceeds max size + fn rotate_if_needed(&self) -> Result<()> { + if let Ok(metadata) = std::fs::metadata(&self.log_path) { + let current_size_mb = metadata.len() / (1024 * 1024); + if current_size_mb >= u64::from(self.config.max_size_mb) { + self.rotate()?; + } + } + Ok(()) + } + + /// Rotate the log file + fn rotate(&self) -> Result<()> { + for i in (1..10).rev() { + let old_name = format!("{}.{}.log", self.log_path.display(), i); + let new_name = format!("{}.{}.log", self.log_path.display(), i + 1); + let _ = std::fs::rename(&old_name, &new_name); + } + + let rotated = format!("{}.1.log", self.log_path.display()); + std::fs::rename(&self.log_path, &rotated)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn audit_event_new_creates_unique_id() { + let event1 = AuditEvent::new(AuditEventType::CommandExecution); + let event2 = AuditEvent::new(AuditEventType::CommandExecution); + assert_ne!(event1.event_id, event2.event_id); + } + + #[test] + fn audit_event_with_actor() { + let event = AuditEvent::new(AuditEventType::CommandExecution).with_actor( + "telegram".to_string(), + Some("123".to_string()), + Some("@alice".to_string()), + ); + + assert!(event.actor.is_some()); + let actor = event.actor.as_ref().unwrap(); + assert_eq!(actor.channel, "telegram"); + assert_eq!(actor.user_id, Some("123".to_string())); + assert_eq!(actor.username, Some("@alice".to_string())); + } + + #[test] + fn audit_event_with_action() { + let event = AuditEvent::new(AuditEventType::CommandExecution).with_action( + "ls -la".to_string(), + "low".to_string(), + false, + true, + ); + + assert!(event.action.is_some()); + let action = event.action.as_ref().unwrap(); + assert_eq!(action.command, Some("ls -la".to_string())); + assert_eq!(action.risk_level, Some("low".to_string())); + } + + #[test] + fn audit_event_serializes_to_json() { + let event = AuditEvent::new(AuditEventType::CommandExecution) + .with_actor("telegram".to_string(), None, None) + .with_action("ls".to_string(), "low".to_string(), false, true) + .with_result(true, Some(0), 15, None); + + let json = serde_json::to_string(&event); + assert!(json.is_ok()); + let json = json.expect("serialize"); + let parsed: AuditEvent = serde_json::from_str(json.as_str()).expect("parse"); + assert!(parsed.actor.is_some()); + assert!(parsed.action.is_some()); + assert!(parsed.result.is_some()); + } + + #[test] + fn audit_logger_disabled_does_not_create_file() -> Result<()> { + let tmp = TempDir::new()?; + let config = AuditConfig { + enabled: false, + ..Default::default() + }; + let logger = AuditLogger::new(config, tmp.path().to_path_buf())?; + let event = AuditEvent::new(AuditEventType::CommandExecution); + + logger.log(&event)?; + + // File should not exist since logging is disabled + assert!(!tmp.path().join("audit.log").exists()); + Ok(()) + } +} diff --git a/src/security/bubblewrap.rs b/src/security/bubblewrap.rs new file mode 100644 index 0000000..fca76e6 --- /dev/null +++ b/src/security/bubblewrap.rs @@ -0,0 +1,97 @@ +//! Bubblewrap sandbox (user namespaces for Linux/macOS) + +use crate::security::traits::Sandbox; +use std::process::Command; + +/// Bubblewrap sandbox backend +#[derive(Debug, Clone, Default)] +pub struct BubblewrapSandbox; + +impl BubblewrapSandbox { + pub fn new() -> std::io::Result { + if Self::is_installed() { + Ok(Self) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Bubblewrap not found", + )) + } + } + + pub fn probe() -> std::io::Result { + Self::new() + } + + fn is_installed() -> bool { + Command::new("bwrap") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } +} + +impl Sandbox for BubblewrapSandbox { + fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { + let program = cmd.get_program().to_string_lossy().to_string(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + + let mut bwrap_cmd = Command::new("bwrap"); + bwrap_cmd.args([ + "--ro-bind", + "/usr", + "/usr", + "--dev", + "/dev", + "--proc", + "/proc", + "--bind", + "/tmp", + "/tmp", + "--unshare-all", + "--die-with-parent", + ]); + bwrap_cmd.arg(&program); + bwrap_cmd.args(&args); + + *cmd = bwrap_cmd; + Ok(()) + } + + fn is_available(&self) -> bool { + Self::is_installed() + } + + fn name(&self) -> &str { + "bubblewrap" + } + + fn description(&self) -> &str { + "User namespace sandbox (requires bwrap)" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bubblewrap_sandbox_name() { + let sandbox = BubblewrapSandbox; + assert_eq!(sandbox.name(), "bubblewrap"); + } + + #[test] + fn bubblewrap_is_available_only_if_installed() { + // Result depends on whether bwrap is installed + let sandbox = BubblewrapSandbox; + let _available = sandbox.is_available(); + + // Either way, the name should still work + assert_eq!(sandbox.name(), "bubblewrap"); + } +} diff --git a/src/security/detect.rs b/src/security/detect.rs new file mode 100644 index 0000000..751d8d0 --- /dev/null +++ b/src/security/detect.rs @@ -0,0 +1,157 @@ +//! Auto-detection of available security features + +use crate::config::{SandboxBackend, SecurityConfig}; +use crate::security::traits::Sandbox; +use std::sync::Arc; + +/// Create a sandbox based on auto-detection or explicit config +pub fn create_sandbox(config: &SecurityConfig) -> Arc { + let backend = &config.sandbox.backend; + + // If explicitly disabled, return noop + if matches!(backend, SandboxBackend::None) || config.sandbox.enabled == Some(false) { + return Arc::new(super::traits::NoopSandbox); + } + + // If specific backend requested, try that + match backend { + SandboxBackend::Landlock => { + #[cfg(feature = "sandbox-landlock")] + { + #[cfg(target_os = "linux")] + { + if let Ok(sandbox) = super::landlock::LandlockSandbox::new() { + return Arc::new(sandbox); + } + } + } + tracing::warn!( + "Landlock requested but not available, falling back to application-layer" + ); + Arc::new(super::traits::NoopSandbox) + } + SandboxBackend::Firejail => { + #[cfg(target_os = "linux")] + { + if let Ok(sandbox) = super::firejail::FirejailSandbox::new() { + return Arc::new(sandbox); + } + } + tracing::warn!( + "Firejail requested but not available, falling back to application-layer" + ); + Arc::new(super::traits::NoopSandbox) + } + SandboxBackend::Bubblewrap => { + #[cfg(feature = "sandbox-bubblewrap")] + { + #[cfg(any(target_os = "linux", target_os = "macos"))] + { + if let Ok(sandbox) = super::bubblewrap::BubblewrapSandbox::new() { + return Arc::new(sandbox); + } + } + } + tracing::warn!( + "Bubblewrap requested but not available, falling back to application-layer" + ); + Arc::new(super::traits::NoopSandbox) + } + SandboxBackend::Docker => { + if let Ok(sandbox) = super::docker::DockerSandbox::new() { + return Arc::new(sandbox); + } + tracing::warn!("Docker requested but not available, falling back to application-layer"); + Arc::new(super::traits::NoopSandbox) + } + SandboxBackend::Auto | SandboxBackend::None => { + // Auto-detect best available + detect_best_sandbox() + } + } +} + +/// Auto-detect the best available sandbox +fn detect_best_sandbox() -> Arc { + #[cfg(target_os = "linux")] + { + // Try Landlock first (native, no dependencies) + #[cfg(feature = "sandbox-landlock")] + { + if let Ok(sandbox) = super::landlock::LandlockSandbox::probe() { + tracing::info!("Landlock sandbox enabled (Linux kernel 5.13+)"); + return Arc::new(sandbox); + } + } + + // Try Firejail second (user-space tool) + if let Ok(sandbox) = super::firejail::FirejailSandbox::probe() { + tracing::info!("Firejail sandbox enabled"); + return Arc::new(sandbox); + } + } + + #[cfg(target_os = "macos")] + { + // Try Bubblewrap on macOS + #[cfg(feature = "sandbox-bubblewrap")] + { + if let Ok(sandbox) = super::bubblewrap::BubblewrapSandbox::probe() { + tracing::info!("Bubblewrap sandbox enabled"); + return Arc::new(sandbox); + } + } + } + + // Docker is heavy but works everywhere if docker is installed + if let Ok(sandbox) = super::docker::DockerSandbox::probe() { + tracing::info!("Docker sandbox enabled"); + return Arc::new(sandbox); + } + + // Fallback: application-layer security only + tracing::info!("No sandbox backend available, using application-layer security"); + Arc::new(super::traits::NoopSandbox) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{SandboxConfig, SecurityConfig}; + + #[test] + fn detect_best_sandbox_returns_something() { + let sandbox = detect_best_sandbox(); + // Should always return at least NoopSandbox + assert!(sandbox.is_available()); + } + + #[test] + fn explicit_none_returns_noop() { + let config = SecurityConfig { + sandbox: SandboxConfig { + enabled: Some(false), + backend: SandboxBackend::None, + firejail_args: Vec::new(), + }, + ..Default::default() + }; + let sandbox = create_sandbox(&config); + assert_eq!(sandbox.name(), "none"); + } + + #[test] + fn auto_mode_detects_something() { + let config = SecurityConfig { + sandbox: SandboxConfig { + enabled: None, // Auto-detect + backend: SandboxBackend::Auto, + firejail_args: Vec::new(), + }, + ..Default::default() + }; + let sandbox = create_sandbox(&config); + // Should return some sandbox (at least NoopSandbox) + assert!(sandbox.is_available()); + } +} diff --git a/src/security/docker.rs b/src/security/docker.rs new file mode 100644 index 0000000..2c32e20 --- /dev/null +++ b/src/security/docker.rs @@ -0,0 +1,120 @@ +//! Docker sandbox (container isolation) + +use crate::security::traits::Sandbox; +use std::process::Command; + +/// Docker sandbox backend +#[derive(Debug, Clone)] +pub struct DockerSandbox { + image: String, +} + +impl Default for DockerSandbox { + fn default() -> Self { + Self { + image: "alpine:latest".to_string(), + } + } +} + +impl DockerSandbox { + pub fn new() -> std::io::Result { + if Self::is_installed() { + Ok(Self::default()) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Docker not found", + )) + } + } + + pub fn with_image(image: String) -> std::io::Result { + if Self::is_installed() { + Ok(Self { image }) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Docker not found", + )) + } + } + + pub fn probe() -> std::io::Result { + Self::new() + } + + fn is_installed() -> bool { + Command::new("docker") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } +} + +impl Sandbox for DockerSandbox { + fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { + let program = cmd.get_program().to_string_lossy().to_string(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + + let mut docker_cmd = Command::new("docker"); + docker_cmd.args([ + "run", + "--rm", + "--memory", + "512m", + "--cpus", + "1.0", + "--network", + "none", + ]); + docker_cmd.arg(&self.image); + docker_cmd.arg(&program); + docker_cmd.args(&args); + + *cmd = docker_cmd; + Ok(()) + } + + fn is_available(&self) -> bool { + Self::is_installed() + } + + fn name(&self) -> &str { + "docker" + } + + fn description(&self) -> &str { + "Docker container isolation (requires docker)" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn docker_sandbox_name() { + let sandbox = DockerSandbox::default(); + assert_eq!(sandbox.name(), "docker"); + } + + #[test] + fn docker_sandbox_default_image() { + let sandbox = DockerSandbox::default(); + assert_eq!(sandbox.image, "alpine:latest"); + } + + #[test] + fn docker_with_custom_image() { + let result = DockerSandbox::with_image("ubuntu:latest".to_string()); + match result { + Ok(sandbox) => assert_eq!(sandbox.image, "ubuntu:latest"), + Err(_) => assert!(!DockerSandbox::is_installed()), + } + } +} diff --git a/src/security/firejail.rs b/src/security/firejail.rs new file mode 100644 index 0000000..9eeb6c7 --- /dev/null +++ b/src/security/firejail.rs @@ -0,0 +1,128 @@ +//! Firejail sandbox (Linux user-space sandboxing) +//! +//! Firejail is a SUID sandbox program that Linux applications use to sandbox themselves. + +use crate::security::traits::Sandbox; +use std::process::Command; + +/// Firejail sandbox backend for Linux +#[derive(Debug, Clone, Default)] +pub struct FirejailSandbox; + +impl FirejailSandbox { + /// Create a new Firejail sandbox + pub fn new() -> std::io::Result { + if Self::is_installed() { + Ok(Self) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Firejail not found. Install with: sudo apt install firejail", + )) + } + } + + /// Probe if Firejail is available (for auto-detection) + pub fn probe() -> std::io::Result { + Self::new() + } + + /// Check if firejail is installed + fn is_installed() -> bool { + Command::new("firejail") + .arg("--version") + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } +} + +impl Sandbox for FirejailSandbox { + fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { + // Prepend firejail to the command + let program = cmd.get_program().to_string_lossy().to_string(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + + // Build firejail wrapper with security flags + let mut firejail_cmd = Command::new("firejail"); + firejail_cmd.args([ + "--private=home", // New home directory + "--private-dev", // Minimal /dev + "--nosound", // No audio + "--no3d", // No 3D acceleration + "--novideo", // No video devices + "--nowheel", // No input devices + "--notv", // No TV devices + "--noprofile", // Skip profile loading + "--quiet", // Suppress warnings + ]); + + // Add the original command + firejail_cmd.arg(&program); + firejail_cmd.args(&args); + + // Replace the command + *cmd = firejail_cmd; + Ok(()) + } + + fn is_available(&self) -> bool { + Self::is_installed() + } + + fn name(&self) -> &str { + "firejail" + } + + fn description(&self) -> &str { + "Linux user-space sandbox (requires firejail to be installed)" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn firejail_sandbox_name() { + assert_eq!(FirejailSandbox.name(), "firejail"); + } + + #[test] + fn firejail_description_mentions_dependency() { + let desc = FirejailSandbox.description(); + assert!(desc.contains("firejail")); + } + + #[test] + fn firejail_new_fails_if_not_installed() { + // This will fail unless firejail is actually installed + let result = FirejailSandbox::new(); + match result { + Ok(_) => println!("Firejail is installed"), + Err(e) => assert!( + e.kind() == std::io::ErrorKind::NotFound + || e.kind() == std::io::ErrorKind::Unsupported + ), + } + } + + #[test] + fn firejail_wrap_command_prepends_firejail() { + let sandbox = FirejailSandbox; + let mut cmd = Command::new("echo"); + cmd.arg("test"); + + // Note: wrap_command will fail if firejail isn't installed, + // but we can still test the logic structure + let _ = sandbox.wrap_command(&mut cmd); + + // After wrapping, the program should be firejail + if sandbox.is_available() { + assert_eq!(cmd.get_program().to_string_lossy(), "firejail"); + } + } +} diff --git a/src/security/landlock.rs b/src/security/landlock.rs new file mode 100644 index 0000000..afb990f --- /dev/null +++ b/src/security/landlock.rs @@ -0,0 +1,206 @@ +//! Landlock sandbox (Linux kernel 5.13+ LSM) +//! +//! Landlock provides unprivileged sandboxing through the Linux kernel. +//! This module uses the pure-Rust `landlock` crate for filesystem access control. + +#[cfg(all(feature = "sandbox-landlock", target_os = "linux"))] +use landlock::{AccessFS, Ruleset, RulesetCreated}; + +use crate::security::traits::Sandbox; +use std::path::Path; + +/// Landlock sandbox backend for Linux +#[cfg(all(feature = "sandbox-landlock", target_os = "linux"))] +#[derive(Debug)] +pub struct LandlockSandbox { + workspace_dir: Option, +} + +#[cfg(all(feature = "sandbox-landlock", target_os = "linux"))] +impl LandlockSandbox { + /// Create a new Landlock sandbox with the given workspace directory + pub fn new() -> std::io::Result { + Self::with_workspace(None) + } + + /// Create a Landlock sandbox with a specific workspace directory + pub fn with_workspace(workspace_dir: Option) -> std::io::Result { + // Test if Landlock is available by trying to create a minimal ruleset + let test_ruleset = Ruleset::new().set_access_fs(AccessFS::read_file | AccessFS::write_file); + + match test_ruleset.create() { + Ok(_) => Ok(Self { workspace_dir }), + Err(e) => { + tracing::debug!("Landlock not available: {}", e); + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Landlock not available", + )) + } + } + } + + /// Probe if Landlock is available (for auto-detection) + pub fn probe() -> std::io::Result { + Self::new() + } + + /// Apply Landlock restrictions to the current process + fn apply_restrictions(&self) -> std::io::Result<()> { + let mut ruleset = Ruleset::new().set_access_fs( + AccessFS::read_file + | AccessFS::write_file + | AccessFS::read_dir + | AccessFS::remove_dir + | AccessFS::remove_file + | AccessFS::make_char + | AccessFS::make_sock + | AccessFS::make_fifo + | AccessFS::make_block + | AccessFS::make_reg + | AccessFS::make_sym, + ); + + // Allow workspace directory (read/write) + if let Some(ref workspace) = self.workspace_dir { + if workspace.exists() { + ruleset = ruleset.add_path( + workspace, + AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir, + )?; + } + } + + // Allow /tmp for general operations + ruleset = ruleset.add_path( + Path::new("/tmp"), + AccessFS::read_file | AccessFS::write_file, + )?; + + // Allow /usr and /bin for executing commands + ruleset = ruleset.add_path(Path::new("/usr"), AccessFS::read_file | AccessFS::read_dir)?; + ruleset = ruleset.add_path(Path::new("/bin"), AccessFS::read_file | AccessFS::read_dir)?; + + // Apply the ruleset + match ruleset.create() { + Ok(_) => { + tracing::debug!("Landlock restrictions applied successfully"); + Ok(()) + } + Err(e) => { + tracing::warn!("Failed to apply Landlock restrictions: {}", e); + Err(std::io::Error::new(std::io::ErrorKind::Other, e)) + } + } + } +} + +#[cfg(all(feature = "sandbox-landlock", target_os = "linux"))] +impl Sandbox for LandlockSandbox { + fn wrap_command(&self, cmd: &mut std::process::Command) -> std::io::Result<()> { + // Apply Landlock restrictions before executing the command + // Note: This affects the current process, not the child process + // Child processes inherit the Landlock restrictions + self.apply_restrictions() + } + + fn is_available(&self) -> bool { + // Try to create a minimal ruleset to verify availability + Ruleset::new() + .set_access_fs(AccessFS::read_file) + .create() + .is_ok() + } + + fn name(&self) -> &str { + "landlock" + } + + fn description(&self) -> &str { + "Linux kernel LSM sandboxing (filesystem access control)" + } +} + +// Stub implementations for non-Linux or when feature is disabled +#[cfg(not(all(feature = "sandbox-landlock", target_os = "linux")))] +pub struct LandlockSandbox; + +#[cfg(not(all(feature = "sandbox-landlock", target_os = "linux")))] +impl LandlockSandbox { + pub fn new() -> std::io::Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Landlock is only supported on Linux with the sandbox-landlock feature", + )) + } + + pub fn with_workspace(_workspace_dir: Option) -> std::io::Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Landlock is only supported on Linux", + )) + } + + pub fn probe() -> std::io::Result { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Landlock is only supported on Linux", + )) + } +} + +#[cfg(not(all(feature = "sandbox-landlock", target_os = "linux")))] +impl Sandbox for LandlockSandbox { + fn wrap_command(&self, _cmd: &mut std::process::Command) -> std::io::Result<()> { + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Landlock is only supported on Linux", + )) + } + + fn is_available(&self) -> bool { + false + } + + fn name(&self) -> &str { + "landlock" + } + + fn description(&self) -> &str { + "Linux kernel LSM sandboxing (not available on this platform)" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(all(feature = "sandbox-landlock", target_os = "linux"))] + #[test] + fn landlock_sandbox_name() { + if let Ok(sandbox) = LandlockSandbox::new() { + assert_eq!(sandbox.name(), "landlock"); + } + } + + #[cfg(not(all(feature = "sandbox-landlock", target_os = "linux")))] + #[test] + fn landlock_not_available_on_non_linux() { + assert!(!LandlockSandbox.is_available()); + assert_eq!(LandlockSandbox.name(), "landlock"); + } + + #[test] + fn landlock_with_none_workspace() { + // Should work even without a workspace directory + let result = LandlockSandbox::with_workspace(None); + // Result depends on platform and feature flag + match result { + Ok(sandbox) => assert!(sandbox.is_available()), + Err(_) => assert!(!cfg!(all( + feature = "sandbox-landlock", + target_os = "linux" + ))), + } + } +} diff --git a/src/security/mod.rs b/src/security/mod.rs index 5a85deb..4009b6f 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -1,9 +1,50 @@ +pub mod audit; +#[cfg(feature = "sandbox-bubblewrap")] +pub mod bubblewrap; +pub mod detect; +pub mod docker; +#[cfg(target_os = "linux")] +pub mod firejail; +#[cfg(feature = "sandbox-landlock")] +pub mod landlock; pub mod pairing; pub mod policy; pub mod secrets; +pub mod traits; +#[allow(unused_imports)] +pub use audit::{AuditEvent, AuditEventType, AuditLogger}; +#[allow(unused_imports)] +pub use detect::create_sandbox; #[allow(unused_imports)] pub use pairing::PairingGuard; pub use policy::{AutonomyLevel, SecurityPolicy}; #[allow(unused_imports)] pub use secrets::SecretStore; +#[allow(unused_imports)] +pub use traits::{NoopSandbox, Sandbox}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn reexported_policy_and_pairing_types_are_usable() { + let policy = SecurityPolicy::default(); + assert_eq!(policy.autonomy, AutonomyLevel::Supervised); + + let guard = PairingGuard::new(false, &[]); + assert!(!guard.require_pairing()); + } + + #[test] + fn reexported_secret_store_encrypt_decrypt_roundtrip() { + let temp = tempfile::tempdir().unwrap(); + let store = SecretStore::new(temp.path(), false); + + let encrypted = store.encrypt("top-secret").unwrap(); + let decrypted = store.decrypt(&encrypted).unwrap(); + + assert_eq!(decrypted, "top-secret"); + } +} diff --git a/src/security/pairing.rs b/src/security/pairing.rs index e176d38..2a828e1 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -8,8 +8,9 @@ // Already-paired tokens are persisted in config so restarts don't require // re-pairing. +use parking_lot::Mutex; +use sha2::{Digest, Sha256}; use std::collections::HashSet; -use std::sync::Mutex; use std::time::Instant; /// Maximum failed pairing attempts before lockout. @@ -18,13 +19,17 @@ const MAX_PAIR_ATTEMPTS: u32 = 5; const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes /// Manages pairing state for the gateway. +/// +/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure +/// in config files. When a new token is generated, the plaintext is returned +/// to the client once, and only the hash is retained. #[derive(Debug)] pub struct PairingGuard { /// Whether pairing is required at all. require_pairing: bool, /// One-time pairing code (generated on startup, consumed on first pair). - pairing_code: Option, - /// Set of valid bearer tokens (persisted across restarts). + pairing_code: Mutex>, + /// Set of SHA-256 hashed bearer tokens (persisted across restarts). paired_tokens: Mutex>, /// Brute-force protection: failed attempt counter + lockout time. failed_attempts: Mutex<(u32, Option)>, @@ -35,8 +40,21 @@ impl PairingGuard { /// /// If `require_pairing` is true and no tokens exist yet, a fresh /// pairing code is generated and returned via `pairing_code()`. + /// + /// Existing tokens are accepted in both forms: + /// - Plaintext (`zc_...`): hashed on load for backward compatibility + /// - Already hashed (64-char hex): stored as-is pub fn new(require_pairing: bool, existing_tokens: &[String]) -> Self { - let tokens: HashSet = existing_tokens.iter().cloned().collect(); + let tokens: HashSet = existing_tokens + .iter() + .map(|t| { + if is_token_hash(t) { + t.clone() + } else { + hash_token(t) + } + }) + .collect(); let code = if require_pairing && tokens.is_empty() { Some(generate_code()) } else { @@ -44,15 +62,15 @@ impl PairingGuard { }; Self { require_pairing, - pairing_code: code, + pairing_code: Mutex::new(code), paired_tokens: Mutex::new(tokens), failed_attempts: Mutex::new((0, None)), } } /// The one-time pairing code (only set when no tokens exist yet). - pub fn pairing_code(&self) -> Option<&str> { - self.pairing_code.as_deref() + pub fn pairing_code(&self) -> Option { + self.pairing_code.lock().clone() } /// Whether pairing is required at all. @@ -65,10 +83,7 @@ impl PairingGuard { pub fn try_pair(&self, code: &str) -> Result, u64> { // Check brute force lockout { - let attempts = self - .failed_attempts - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let attempts = self.failed_attempts.lock(); if let (count, Some(locked_at)) = &*attempts { if *count >= MAX_PAIR_ATTEMPTS { let elapsed = locked_at.elapsed().as_secs(); @@ -79,32 +94,30 @@ impl PairingGuard { } } - if let Some(ref expected) = self.pairing_code { - if constant_time_eq(code.trim(), expected.trim()) { - // Reset failed attempts on success - { - let mut attempts = self - .failed_attempts - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *attempts = (0, None); + { + let mut pairing_code = self.pairing_code.lock(); + if let Some(ref expected) = *pairing_code { + if constant_time_eq(code.trim(), expected.trim()) { + // Reset failed attempts on success + { + let mut attempts = self.failed_attempts.lock(); + *attempts = (0, None); + } + let token = generate_token(); + let mut tokens = self.paired_tokens.lock(); + tokens.insert(hash_token(&token)); + + // Consume the pairing code so it cannot be reused + *pairing_code = None; + + return Ok(Some(token)); } - let token = generate_token(); - let mut tokens = self - .paired_tokens - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - tokens.insert(token.clone()); - return Ok(Some(token)); } } // Increment failed attempts { - let mut attempts = self - .failed_attempts - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut attempts = self.failed_attempts.lock(); attempts.0 += 1; if attempts.0 >= MAX_PAIR_ATTEMPTS { attempts.1 = Some(Instant::now()); @@ -114,33 +127,25 @@ impl PairingGuard { Ok(None) } - /// Check if a bearer token is valid. + /// Check if a bearer token is valid (compares against stored hashes). pub fn is_authenticated(&self, token: &str) -> bool { if !self.require_pairing { return true; } - let tokens = self - .paired_tokens - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - tokens.contains(token) + let hashed = hash_token(token); + let tokens = self.paired_tokens.lock(); + tokens.contains(&hashed) } /// Returns true if the gateway is already paired (has at least one token). pub fn is_paired(&self) -> bool { - let tokens = self - .paired_tokens - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let tokens = self.paired_tokens.lock(); !tokens.is_empty() } - /// Get all paired tokens (for persisting to config). + /// Get all paired token hashes (for persisting to config). pub fn tokens(&self) -> Vec { - let tokens = self - .paired_tokens - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let tokens = self.paired_tokens.lock(); tokens.iter().cloned().collect() } } @@ -169,9 +174,28 @@ fn generate_code() -> String { } } -/// Generate a cryptographically-adequate bearer token (hex-encoded). +/// Generate a cryptographically-adequate bearer token with 256-bit entropy. +/// +/// Uses `rand::thread_rng()` which is backed by the OS CSPRNG +/// (/dev/urandom on Linux, BCryptGenRandom on Windows, SecRandomCopyBytes +/// on macOS). The 32 random bytes (256 bits) are hex-encoded for a +/// 64-character token, providing 256 bits of entropy. fn generate_token() -> String { - format!("zc_{}", uuid::Uuid::new_v4().as_simple()) + use rand::RngCore; + let mut bytes = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut bytes); + format!("zc_{}", hex::encode(bytes)) +} + +/// SHA-256 hash a bearer token for storage. Returns lowercase hex. +fn hash_token(token: &str) -> String { + format!("{:x}", Sha256::digest(token.as_bytes())) +} + +/// Check if a stored value looks like a SHA-256 hash (64 hex chars) +/// rather than a plaintext token. +fn is_token_hash(value: &str) -> bool { + value.len() == 64 && value.chars().all(|c| c.is_ascii_hexdigit()) } /// Constant-time string comparison to prevent timing attacks. @@ -258,10 +282,19 @@ mod tests { #[test] fn is_authenticated_with_valid_token() { + // Pass plaintext token — PairingGuard hashes it on load let guard = PairingGuard::new(true, &["zc_valid".into()]); assert!(guard.is_authenticated("zc_valid")); } + #[test] + fn is_authenticated_with_prehashed_token() { + // Pass an already-hashed token (64 hex chars) + let hashed = hash_token("zc_valid"); + let guard = PairingGuard::new(true, &[hashed]); + assert!(guard.is_authenticated("zc_valid")); + } + #[test] fn is_authenticated_with_invalid_token() { let guard = PairingGuard::new(true, &["zc_valid".into()]); @@ -276,11 +309,16 @@ mod tests { } #[test] - fn tokens_returns_all_paired() { - let guard = PairingGuard::new(true, &["a".into(), "b".into()]); - let mut tokens = guard.tokens(); - tokens.sort(); - assert_eq!(tokens, vec!["a", "b"]); + fn tokens_returns_hashes() { + let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]); + let tokens = guard.tokens(); + assert_eq!(tokens.len(), 2); + // Tokens should be stored as 64-char hex hashes, not plaintext + for t in &tokens { + assert_eq!(t.len(), 64, "Token should be a SHA-256 hash"); + assert!(t.chars().all(|c| c.is_ascii_hexdigit())); + assert!(!t.starts_with("zc_"), "Token should not be plaintext"); + } } #[test] @@ -292,6 +330,33 @@ mod tests { assert!(!guard.is_authenticated("wrong")); } + // ── Token hashing ──────────────────────────────────────── + + #[test] + fn hash_token_produces_64_hex_chars() { + let hash = hash_token("zc_test_token"); + assert_eq!(hash.len(), 64); + assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn hash_token_is_deterministic() { + assert_eq!(hash_token("zc_abc"), hash_token("zc_abc")); + } + + #[test] + fn hash_token_differs_for_different_inputs() { + assert_ne!(hash_token("zc_a"), hash_token("zc_b")); + } + + #[test] + fn is_token_hash_detects_hash_vs_plaintext() { + assert!(is_token_hash(&hash_token("zc_test"))); + assert!(!is_token_hash("zc_test_token")); + assert!(!is_token_hash("too_short")); + assert!(!is_token_hash("")); + } + // ── is_public_bind ─────────────────────────────────────── #[test] diff --git a/src/security/policy.rs b/src/security/policy.rs index 49d58df..7db3ef8 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -1,6 +1,6 @@ +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; -use std::sync::Mutex; use std::time::Instant; /// How much autonomy the agent has @@ -16,6 +16,14 @@ pub enum AutonomyLevel { Full, } +/// Risk score for shell command execution. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandRiskLevel { + Low, + Medium, + High, +} + /// Sliding-window action tracker for rate limiting. #[derive(Debug)] pub struct ActionTracker { @@ -32,10 +40,7 @@ impl ActionTracker { /// Record an action and return the current count within the window. pub fn record(&self) -> usize { - let mut actions = self - .actions - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut actions = self.actions.lock(); let cutoff = Instant::now() .checked_sub(std::time::Duration::from_secs(3600)) .unwrap_or_else(Instant::now); @@ -46,10 +51,7 @@ impl ActionTracker { /// Count of actions in the current window without recording. pub fn count(&self) -> usize { - let mut actions = self - .actions - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut actions = self.actions.lock(); let cutoff = Instant::now() .checked_sub(std::time::Duration::from_secs(3600)) .unwrap_or_else(Instant::now); @@ -60,10 +62,7 @@ impl ActionTracker { impl Clone for ActionTracker { fn clone(&self) -> Self { - let actions = self - .actions - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let actions = self.actions.lock(); Self { actions: Mutex::new(actions.clone()), } @@ -80,6 +79,8 @@ pub struct SecurityPolicy { pub forbidden_paths: Vec, pub max_actions_per_hour: u32, pub max_cost_per_day_cents: u32, + pub require_approval_for_medium_risk: bool, + pub block_high_risk_commands: bool, pub tracker: ActionTracker, } @@ -127,6 +128,8 @@ impl Default for SecurityPolicy { ], max_actions_per_hour: 20, max_cost_per_day_cents: 500, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, tracker: ActionTracker::new(), } } @@ -155,14 +158,192 @@ fn skip_env_assignments(s: &str) -> &str { } } +/// Detect a single `&` operator (background/chain). `&&` is allowed. +/// +/// We treat any standalone `&` as unsafe in policy validation because it can +/// chain hidden sub-commands and escape foreground timeout expectations. +fn contains_single_ampersand(s: &str) -> bool { + let bytes = s.as_bytes(); + for (i, b) in bytes.iter().enumerate() { + if *b != b'&' { + continue; + } + let prev_is_amp = i > 0 && bytes[i - 1] == b'&'; + let next_is_amp = i + 1 < bytes.len() && bytes[i + 1] == b'&'; + if !prev_is_amp && !next_is_amp { + return true; + } + } + false +} + impl SecurityPolicy { + /// Classify command risk. Any high-risk segment marks the whole command high. + pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { + let mut normalized = command.to_string(); + for sep in ["&&", "||"] { + normalized = normalized.replace(sep, "\x00"); + } + for sep in ['\n', ';', '|', '&'] { + normalized = normalized.replace(sep, "\x00"); + } + + let mut saw_medium = false; + + for segment in normalized.split('\x00') { + let segment = segment.trim(); + if segment.is_empty() { + continue; + } + + let cmd_part = skip_env_assignments(segment); + let mut words = cmd_part.split_whitespace(); + let Some(base_raw) = words.next() else { + continue; + }; + + let base = base_raw + .rsplit('/') + .next() + .unwrap_or("") + .to_ascii_lowercase(); + + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + let joined_segment = cmd_part.to_ascii_lowercase(); + + // High-risk commands + if matches!( + base.as_str(), + "rm" | "mkfs" + | "dd" + | "shutdown" + | "reboot" + | "halt" + | "poweroff" + | "sudo" + | "su" + | "chown" + | "chmod" + | "useradd" + | "userdel" + | "usermod" + | "passwd" + | "mount" + | "umount" + | "iptables" + | "ufw" + | "firewall-cmd" + | "curl" + | "wget" + | "nc" + | "ncat" + | "netcat" + | "scp" + | "ssh" + | "ftp" + | "telnet" + ) { + return CommandRiskLevel::High; + } + + if joined_segment.contains("rm -rf /") + || joined_segment.contains("rm -fr /") + || joined_segment.contains(":(){:|:&};:") + { + return CommandRiskLevel::High; + } + + // Medium-risk commands (state-changing, but not inherently destructive) + let medium = match base.as_str() { + "git" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "commit" + | "push" + | "reset" + | "clean" + | "rebase" + | "merge" + | "cherry-pick" + | "revert" + | "branch" + | "checkout" + | "switch" + | "tag" + ) + }), + "npm" | "pnpm" | "yarn" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "install" | "add" | "remove" | "uninstall" | "update" | "publish" + ) + }), + "cargo" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "add" | "remove" | "install" | "clean" | "publish" + ) + }), + "touch" | "mkdir" | "mv" | "cp" | "ln" => true, + _ => false, + }; + + saw_medium |= medium; + } + + if saw_medium { + CommandRiskLevel::Medium + } else { + CommandRiskLevel::Low + } + } + + /// Validate full command execution policy (allowlist + risk gate). + pub fn validate_command_execution( + &self, + command: &str, + approved: bool, + ) -> Result { + if !self.is_command_allowed(command) { + return Err(format!("Command not allowed by security policy: {command}")); + } + + let risk = self.command_risk_level(command); + + if risk == CommandRiskLevel::High { + if self.block_high_risk_commands { + return Err("Command blocked: high-risk command is disallowed by policy".into()); + } + if self.autonomy == AutonomyLevel::Supervised && !approved { + return Err( + "Command requires explicit approval (approved=true): high-risk operation" + .into(), + ); + } + } + + if risk == CommandRiskLevel::Medium + && self.autonomy == AutonomyLevel::Supervised + && self.require_approval_for_medium_risk + && !approved + { + return Err( + "Command requires explicit approval (approved=true): medium-risk operation".into(), + ); + } + + Ok(risk) + } + /// Check if a shell command is allowed. /// /// Validates the **entire** command string, not just the first word: /// - Blocks subshell operators (`` ` ``, `$(`) that hide arbitrary execution /// - Splits on command separators (`|`, `&&`, `||`, `;`, newlines) and /// validates each sub-command against the allowlist + /// - Blocks single `&` background chaining (`&&` remains supported) /// - Blocks output redirections (`>`, `>>`) that could write outside workspace + /// - Blocks dangerous arguments (e.g. `find -exec`, `git config`) pub fn is_command_allowed(&self, command: &str) -> bool { if self.autonomy == AutonomyLevel::ReadOnly { return false; @@ -170,7 +351,12 @@ impl SecurityPolicy { // Block subshell/expansion operators — these allow hiding arbitrary // commands inside an allowed command (e.g. `echo $(rm -rf /)`) - if command.contains('`') || command.contains("$(") || command.contains("${") { + if command.contains('`') + || command.contains("$(") + || command.contains("${") + || command.contains("<(") + || command.contains(">(") + { return false; } @@ -179,6 +365,21 @@ impl SecurityPolicy { return false; } + // Block `tee` — it can write to arbitrary files, bypassing the + // redirect check above (e.g. `echo secret | tee /etc/crontab`) + if command + .split_whitespace() + .any(|w| w == "tee" || w.ends_with("/tee")) + { + return false; + } + + // Block background command chaining (`&`), which can hide extra + // sub-commands and outlive timeout expectations. Keep `&&` allowed. + if contains_single_ampersand(command) { + return false; + } + // Split on command separators and validate each sub-command. // We collect segments by scanning for separator characters. let mut normalized = command.to_string(); @@ -198,13 +399,9 @@ impl SecurityPolicy { // Strip leading env var assignments (e.g. FOO=bar cmd) let cmd_part = skip_env_assignments(segment); - let base_cmd = cmd_part - .split_whitespace() - .next() - .unwrap_or("") - .rsplit('/') - .next() - .unwrap_or(""); + let mut words = cmd_part.split_whitespace(); + let base_raw = words.next().unwrap_or(""); + let base_cmd = base_raw.rsplit('/').next().unwrap_or(""); if base_cmd.is_empty() { continue; @@ -217,6 +414,12 @@ impl SecurityPolicy { { return false; } + + // Validate arguments for the command + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + if !self.is_args_safe(base_cmd, &args) { + return false; + } } // At least one command must be present @@ -228,6 +431,29 @@ impl SecurityPolicy { has_cmd } + /// Check for dangerous arguments that allow sub-command execution. + fn is_args_safe(&self, base: &str, args: &[String]) -> bool { + let base = base.to_ascii_lowercase(); + match base.as_str() { + "find" => { + // find -exec and find -ok allow arbitrary command execution + !args.iter().any(|arg| arg == "-exec" || arg == "-ok") + } + "git" => { + // git config, alias, and -c can be used to set dangerous options + // (e.g. git config core.editor "rm -rf /") + !args.iter().any(|arg| { + arg == "config" + || arg.starts_with("config.") + || arg == "alias" + || arg.starts_with("alias.") + || arg == "-c" + }) + } + _ => true, + } + } + /// Check if a file path is allowed (no path traversal, within workspace) pub fn is_path_allowed(&self, path: &str) -> bool { // Block null bytes (can truncate paths in C-backed syscalls) @@ -235,19 +461,50 @@ impl SecurityPolicy { return false; } - // Block obvious traversal attempts - if path.contains("..") { + // Block path traversal: check for ".." as a path component + if Path::new(path) + .components() + .any(|c| matches!(c, std::path::Component::ParentDir)) + { return false; } + // Block URL-encoded traversal attempts (e.g. ..%2f) + let lower = path.to_lowercase(); + if lower.contains("..%2f") || lower.contains("%2f..") { + return false; + } + + // Expand tilde for comparison + let expanded = if let Some(stripped) = path.strip_prefix("~/") { + if let Some(home) = std::env::var("HOME").ok().map(PathBuf::from) { + home.join(stripped).to_string_lossy().to_string() + } else { + path.to_string() + } + } else { + path.to_string() + }; + // Block absolute paths when workspace_only is set - if self.workspace_only && Path::new(path).is_absolute() { + if self.workspace_only && Path::new(&expanded).is_absolute() { return false; } - // Block forbidden paths + // Block forbidden paths using path-component-aware matching + let expanded_path = Path::new(&expanded); for forbidden in &self.forbidden_paths { - if path.starts_with(forbidden.as_str()) { + let forbidden_expanded = if let Some(stripped) = forbidden.strip_prefix("~/") { + if let Some(home) = std::env::var("HOME").ok().map(PathBuf::from) { + home.join(stripped).to_string_lossy().to_string() + } else { + forbidden.clone() + } + } else { + forbidden.clone() + }; + let forbidden_path = Path::new(&forbidden_expanded); + if expanded_path.starts_with(forbidden_path) { return false; } } @@ -298,6 +555,8 @@ impl SecurityPolicy { forbidden_paths: autonomy_config.forbidden_paths.clone(), max_actions_per_hour: autonomy_config.max_actions_per_hour, max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents, + require_approval_for_medium_risk: autonomy_config.require_approval_for_medium_risk, + block_high_risk_commands: autonomy_config.block_high_risk_commands, tracker: ActionTracker::new(), } } @@ -442,6 +701,79 @@ mod tests { assert!(!p.is_command_allowed("echo hello")); } + #[test] + fn command_risk_low_for_read_commands() { + let p = default_policy(); + assert_eq!(p.command_risk_level("git status"), CommandRiskLevel::Low); + assert_eq!(p.command_risk_level("ls -la"), CommandRiskLevel::Low); + } + + #[test] + fn command_risk_medium_for_mutating_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["git".into(), "touch".into()], + ..SecurityPolicy::default() + }; + assert_eq!( + p.command_risk_level("git reset --hard HEAD~1"), + CommandRiskLevel::Medium + ); + assert_eq!( + p.command_risk_level("touch file.txt"), + CommandRiskLevel::Medium + ); + } + + #[test] + fn command_risk_high_for_dangerous_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["rm".into()], + ..SecurityPolicy::default() + }; + assert_eq!( + p.command_risk_level("rm -rf /tmp/test"), + CommandRiskLevel::High + ); + } + + #[test] + fn validate_command_requires_approval_for_medium_risk() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + require_approval_for_medium_risk: true, + allowed_commands: vec!["touch".into()], + ..SecurityPolicy::default() + }; + + let denied = p.validate_command_execution("touch test.txt", false); + assert!(denied.is_err()); + assert!(denied.unwrap_err().contains("requires explicit approval"),); + + let allowed = p.validate_command_execution("touch test.txt", true); + assert_eq!(allowed.unwrap(), CommandRiskLevel::Medium); + } + + #[test] + fn validate_command_blocks_high_risk_by_default() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + allowed_commands: vec!["rm".into()], + ..SecurityPolicy::default() + }; + + let result = p.validate_command_execution("rm -rf /tmp/test", true); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("high-risk")); + } + + #[test] + fn validate_command_rejects_background_chain_bypass() { + let p = default_policy(); + let result = p.validate_command_execution("ls & python3 -c 'print(1)'", false); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("not allowed")); + } + // ── is_path_allowed ───────────────────────────────────── #[test] @@ -515,6 +847,9 @@ mod tests { forbidden_paths: vec!["/secret".into()], max_actions_per_hour: 100, max_cost_per_day_cents: 1000, + require_approval_for_medium_risk: false, + block_high_risk_commands: false, + ..crate::config::AutonomyConfig::default() }; let workspace = PathBuf::from("/tmp/test-workspace"); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); @@ -525,6 +860,8 @@ mod tests { assert_eq!(policy.forbidden_paths, vec!["/secret"]); assert_eq!(policy.max_actions_per_hour, 100); assert_eq!(policy.max_cost_per_day_cents, 1000); + assert!(!policy.require_approval_for_medium_risk); + assert!(!policy.block_high_risk_commands); assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace")); } @@ -539,6 +876,8 @@ mod tests { assert!(!p.forbidden_paths.is_empty()); assert!(p.max_actions_per_hour > 0); assert!(p.max_cost_per_day_cents > 0); + assert!(p.require_approval_for_medium_risk); + assert!(p.block_high_risk_commands); } // ── ActionTracker / rate limiting ─────────────────────── @@ -669,6 +1008,14 @@ mod tests { assert!(p.is_command_allowed("ls || echo fallback")); } + #[test] + fn command_injection_background_chain_blocked() { + let p = default_policy(); + assert!(!p.is_command_allowed("ls & rm -rf /")); + assert!(!p.is_command_allowed("ls&rm -rf /")); + assert!(!p.is_command_allowed("echo ok & python3 -c 'print(1)'")); + } + #[test] fn command_injection_redirect_blocked() { let p = default_policy(); @@ -676,12 +1023,43 @@ mod tests { assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); } + #[test] + fn command_argument_injection_blocked() { + let p = default_policy(); + // find -exec is a common bypass + assert!(!p.is_command_allowed("find . -exec rm -rf {} +")); + assert!(!p.is_command_allowed("find / -ok cat {} \\;")); + // git config/alias can execute commands + assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\"")); + assert!(!p.is_command_allowed("git alias.st status")); + assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit")); + // Legitimate commands should still work + assert!(p.is_command_allowed("find . -name '*.txt'")); + assert!(p.is_command_allowed("git status")); + assert!(p.is_command_allowed("git add .")); + } + #[test] fn command_injection_dollar_brace_blocked() { let p = default_policy(); assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd")); } + #[test] + fn command_injection_tee_blocked() { + let p = default_policy(); + assert!(!p.is_command_allowed("echo secret | tee /etc/crontab")); + assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile")); + assert!(!p.is_command_allowed("tee file.txt")); + } + + #[test] + fn command_injection_process_substitution_blocked() { + let p = default_policy(); + assert!(!p.is_command_allowed("cat <(echo pwned)")); + assert!(!p.is_command_allowed("ls >(cat /etc/passwd)")); + } + #[test] fn command_env_var_prefix_with_allowed_cmd() { let p = default_policy(); @@ -704,8 +1082,11 @@ mod tests { #[test] fn path_traversal_double_dot_in_filename() { let p = default_policy(); - // ".." anywhere in the path is blocked (conservative) - assert!(!p.is_path_allowed("my..file.txt")); + // ".." in a filename (not a path component) is allowed + assert!(p.is_path_allowed("my..file.txt")); + // But actual traversal components are still blocked + assert!(!p.is_path_allowed("../etc/passwd")); + assert!(!p.is_path_allowed("foo/../etc/passwd")); } #[test] @@ -819,6 +1200,9 @@ mod tests { forbidden_paths: vec![], max_actions_per_hour: 10, max_cost_per_day_cents: 100, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, + ..crate::config::AutonomyConfig::default() }; let workspace = PathBuf::from("/tmp/test"); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); diff --git a/src/security/traits.rs b/src/security/traits.rs new file mode 100644 index 0000000..06fc4ef --- /dev/null +++ b/src/security/traits.rs @@ -0,0 +1,81 @@ +//! Sandbox trait for pluggable OS-level isolation + +use async_trait::async_trait; +use std::process::Command; + +/// Sandbox backend for OS-level isolation +#[async_trait] +pub trait Sandbox: Send + Sync { + /// Wrap a command with sandbox protection + fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()>; + + /// Check if this sandbox backend is available on the current platform + fn is_available(&self) -> bool; + + /// Human-readable name of this sandbox backend + fn name(&self) -> &str; + + /// Description of what this sandbox provides + fn description(&self) -> &str; +} + +/// No-op sandbox (always available, provides no additional isolation) +#[derive(Debug, Clone, Default)] +pub struct NoopSandbox; + +impl Sandbox for NoopSandbox { + fn wrap_command(&self, _cmd: &mut Command) -> std::io::Result<()> { + // Pass through unchanged + Ok(()) + } + + fn is_available(&self) -> bool { + true + } + + fn name(&self) -> &str { + "none" + } + + fn description(&self) -> &str { + "No sandboxing (application-layer security only)" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn noop_sandbox_name() { + assert_eq!(NoopSandbox.name(), "none"); + } + + #[test] + fn noop_sandbox_is_always_available() { + assert!(NoopSandbox.is_available()); + } + + #[test] + fn noop_sandbox_wrap_command_is_noop() { + let mut cmd = Command::new("echo"); + cmd.arg("test"); + let original_program = cmd.get_program().to_string_lossy().to_string(); + let original_args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); + + let sandbox = NoopSandbox; + assert!(sandbox.wrap_command(&mut cmd).is_ok()); + + // Command should be unchanged + assert_eq!(cmd.get_program().to_string_lossy(), original_program); + assert_eq!( + cmd.get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect::>(), + original_args + ); + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 4b3d2b3..287f446 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -11,13 +11,13 @@ fn windows_task_name() -> &'static str { WINDOWS_TASK_NAME } -pub fn handle_command(command: &super::ServiceCommands, config: &Config) -> Result<()> { +pub fn handle_command(command: &crate::ServiceCommands, config: &Config) -> Result<()> { match command { - super::ServiceCommands::Install => install(config), - super::ServiceCommands::Start => start(config), - super::ServiceCommands::Stop => stop(config), - super::ServiceCommands::Status => status(config), - super::ServiceCommands::Uninstall => uninstall(config), + crate::ServiceCommands::Install => install(config), + crate::ServiceCommands::Start => start(config), + crate::ServiceCommands::Stop => stop(config), + crate::ServiceCommands::Status => status(config), + crate::ServiceCommands::Uninstall => uninstall(config), } } diff --git a/src/skillforge/evaluate.rs b/src/skillforge/evaluate.rs new file mode 100644 index 0000000..bdefd59 --- /dev/null +++ b/src/skillforge/evaluate.rs @@ -0,0 +1,272 @@ +//! Evaluator — scores discovered skill candidates across multiple dimensions. + +use serde::{Deserialize, Serialize}; + +use super::scout::ScoutResult; + +// --------------------------------------------------------------------------- +// Scoring dimensions +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Scores { + /// OS / arch / runtime compatibility (0.0–1.0). + pub compatibility: f64, + /// Code quality signals: stars, tests, docs (0.0–1.0). + pub quality: f64, + /// Security posture: license, known-bad patterns (0.0–1.0). + pub security: f64, +} + +impl Scores { + /// Weighted total. Weights: compatibility 0.3, quality 0.35, security 0.35. + pub fn total(&self) -> f64 { + self.compatibility * 0.30 + self.quality * 0.35 + self.security * 0.35 + } +} + +// --------------------------------------------------------------------------- +// Recommendation +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum Recommendation { + /// Score >= threshold → safe to auto-integrate. + Auto, + /// Score in [0.4, threshold) → needs human review. + Manual, + /// Score < 0.4 → skip entirely. + Skip, +} + +// --------------------------------------------------------------------------- +// EvalResult +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EvalResult { + pub candidate: ScoutResult, + pub scores: Scores, + pub total_score: f64, + pub recommendation: Recommendation, +} + +// --------------------------------------------------------------------------- +// Evaluator +// --------------------------------------------------------------------------- + +pub struct Evaluator { + /// Minimum total score for auto-integration. + min_score: f64, +} + +/// Known-bad patterns in repo names / descriptions (matched as whole words). +const BAD_PATTERNS: &[&str] = &[ + "malware", + "exploit", + "hack", + "crack", + "keygen", + "ransomware", + "trojan", +]; + +/// Check if `haystack` contains `word` as a whole word (bounded by non-alphanumeric chars). +fn contains_word(haystack: &str, word: &str) -> bool { + for (i, _) in haystack.match_indices(word) { + let before_ok = i == 0 || !haystack.as_bytes()[i - 1].is_ascii_alphanumeric(); + let after = i + word.len(); + let after_ok = + after >= haystack.len() || !haystack.as_bytes()[after].is_ascii_alphanumeric(); + if before_ok && after_ok { + return true; + } + } + false +} + +impl Evaluator { + pub fn new(min_score: f64) -> Self { + Self { min_score } + } + + pub fn evaluate(&self, candidate: ScoutResult) -> EvalResult { + let compatibility = self.score_compatibility(&candidate); + let quality = self.score_quality(&candidate); + let security = self.score_security(&candidate); + + let scores = Scores { + compatibility, + quality, + security, + }; + let total_score = scores.total(); + + let recommendation = if total_score >= self.min_score { + Recommendation::Auto + } else if total_score >= 0.4 { + Recommendation::Manual + } else { + Recommendation::Skip + }; + + EvalResult { + candidate, + scores, + total_score, + recommendation, + } + } + + // -- Dimension scorers -------------------------------------------------- + + /// Compatibility: favour Rust repos; penalise unknown languages. + fn score_compatibility(&self, c: &ScoutResult) -> f64 { + match c.language.as_deref() { + Some("Rust") => 1.0, + Some("Python" | "TypeScript" | "JavaScript") => 0.6, + Some(_) => 0.3, + None => 0.2, + } + } + + /// Quality: based on star count (log scale, capped at 1.0). + fn score_quality(&self, c: &ScoutResult) -> f64 { + // log2(stars + 1) / 10, capped at 1.0 + let raw = ((c.stars as f64) + 1.0).log2() / 10.0; + raw.min(1.0) + } + + /// Security: license presence + bad-pattern check. + fn score_security(&self, c: &ScoutResult) -> f64 { + let mut score: f64 = 0.5; + + // License bonus + if c.has_license { + score += 0.3; + } + + // Bad-pattern penalty (whole-word match) + let lower_name = c.name.to_lowercase(); + let lower_desc = c.description.to_lowercase(); + for pat in BAD_PATTERNS { + if contains_word(&lower_name, pat) || contains_word(&lower_desc, pat) { + score -= 0.5; + break; + } + } + + // Recency bonus: updated within last 180 days (guard against future timestamps) + if let Some(updated) = c.updated_at { + let age_days = (chrono::Utc::now() - updated).num_days(); + if (0..180).contains(&age_days) { + score += 0.2; + } + } + + score.clamp(0.0, 1.0) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::skillforge::scout::{ScoutResult, ScoutSource}; + + fn make_candidate(stars: u64, lang: Option<&str>, has_license: bool) -> ScoutResult { + ScoutResult { + name: "test-skill".into(), + url: "https://github.com/test/test-skill".into(), + description: "A test skill".into(), + stars, + language: lang.map(String::from), + updated_at: Some(chrono::Utc::now()), + source: ScoutSource::GitHub, + owner: "test".into(), + has_license, + } + } + + #[test] + fn high_quality_rust_repo_gets_auto() { + let eval = Evaluator::new(0.7); + let c = make_candidate(500, Some("Rust"), true); + let res = eval.evaluate(c); + assert!(res.total_score >= 0.7, "score: {}", res.total_score); + assert_eq!(res.recommendation, Recommendation::Auto); + } + + #[test] + fn low_star_no_license_gets_manual_or_skip() { + let eval = Evaluator::new(0.7); + let c = make_candidate(1, None, false); + let res = eval.evaluate(c); + assert!(res.total_score < 0.7, "score: {}", res.total_score); + assert_ne!(res.recommendation, Recommendation::Auto); + } + + #[test] + fn bad_pattern_tanks_security() { + let eval = Evaluator::new(0.7); + let mut c = make_candidate(1000, Some("Rust"), true); + c.name = "malware-skill".into(); + let res = eval.evaluate(c); + // 0.5 base + 0.3 license - 0.5 bad_pattern + 0.2 recency = 0.5 + assert!( + res.scores.security <= 0.5, + "security: {}", + res.scores.security + ); + } + + #[test] + fn scores_total_weighted() { + let s = Scores { + compatibility: 1.0, + quality: 1.0, + security: 1.0, + }; + assert!((s.total() - 1.0).abs() < f64::EPSILON); + + let s2 = Scores { + compatibility: 0.0, + quality: 0.0, + security: 0.0, + }; + assert!((s2.total()).abs() < f64::EPSILON); + } + + #[test] + fn hackathon_not_flagged_as_bad() { + let eval = Evaluator::new(0.7); + let mut c = make_candidate(500, Some("Rust"), true); + c.name = "hackathon-tools".into(); + c.description = "Tools for hackathons and lifehacks".into(); + let res = eval.evaluate(c); + // "hack" should NOT match "hackathon" or "lifehacks" + assert!( + res.scores.security >= 0.5, + "security: {}", + res.scores.security + ); + } + + #[test] + fn exact_hack_is_flagged() { + let eval = Evaluator::new(0.7); + let mut c = make_candidate(500, Some("Rust"), false); + c.name = "hack-tool".into(); + c.updated_at = None; + let res = eval.evaluate(c); + // 0.5 base + 0.0 license - 0.5 bad_pattern + 0.0 recency = 0.0 + assert!( + res.scores.security < 0.5, + "security: {}", + res.scores.security + ); + } +} diff --git a/src/skillforge/integrate.rs b/src/skillforge/integrate.rs new file mode 100644 index 0000000..540dd8b --- /dev/null +++ b/src/skillforge/integrate.rs @@ -0,0 +1,248 @@ +//! Integrator — generates ZeroClaw-standard SKILL.toml + SKILL.md from scout results. + +use std::fs; +use std::path::PathBuf; + +use anyhow::{bail, Context, Result}; +use chrono::Utc; +use tracing::info; + +use super::scout::ScoutResult; + +// --------------------------------------------------------------------------- +// Integrator +// --------------------------------------------------------------------------- + +pub struct Integrator { + output_dir: PathBuf, +} + +impl Integrator { + pub fn new(output_dir: String) -> Self { + Self { + output_dir: PathBuf::from(output_dir), + } + } + + /// Write SKILL.toml and SKILL.md for the given candidate. + pub fn integrate(&self, candidate: &ScoutResult) -> Result { + let safe_name = sanitize_path_component(&candidate.name)?; + let skill_dir = self.output_dir.join(&safe_name); + fs::create_dir_all(&skill_dir) + .with_context(|| format!("Failed to create dir: {}", skill_dir.display()))?; + + let toml_path = skill_dir.join("SKILL.toml"); + let md_path = skill_dir.join("SKILL.md"); + + let toml_content = self.generate_toml(candidate); + let md_content = self.generate_md(candidate); + + fs::write(&toml_path, &toml_content) + .with_context(|| format!("Failed to write {}", toml_path.display()))?; + fs::write(&md_path, &md_content) + .with_context(|| format!("Failed to write {}", md_path.display()))?; + + info!( + skill = candidate.name.as_str(), + path = %skill_dir.display(), + "Integrated skill" + ); + + Ok(skill_dir) + } + + // -- Generators --------------------------------------------------------- + + fn generate_toml(&self, c: &ScoutResult) -> String { + let lang = c.language.as_deref().unwrap_or("unknown"); + let updated = c + .updated_at + .map(|d| d.format("%Y-%m-%d").to_string()) + .unwrap_or_else(|| "unknown".into()); + + format!( + r#"# Auto-generated by SkillForge on {now} + +[skill] +name = "{name}" +version = "0.1.0" +description = "{description}" +source = "{url}" +owner = "{owner}" +language = "{lang}" +license = {license} +stars = {stars} +updated_at = "{updated}" + +[skill.requirements] +runtime = "zeroclaw >= 0.1" + +[skill.metadata] +auto_integrated = true +forge_timestamp = "{now}" +"#, + now = Utc::now().format("%Y-%m-%dT%H:%M:%SZ"), + name = escape_toml(&c.name), + description = escape_toml(&c.description), + url = escape_toml(&c.url), + owner = escape_toml(&c.owner), + lang = lang, + license = if c.has_license { "true" } else { "false" }, + stars = c.stars, + updated = updated, + ) + } + + fn generate_md(&self, c: &ScoutResult) -> String { + let lang = c.language.as_deref().unwrap_or("unknown"); + format!( + r#"# {name} + +> Auto-generated by SkillForge + +## Overview + +- **Source**: [{url}]({url}) +- **Owner**: {owner} +- **Language**: {lang} +- **Stars**: {stars} +- **License**: {license} + +## Description + +{description} + +## Usage + +```toml +# Add to your ZeroClaw config: +[skills.{name}] +enabled = true +``` + +## Notes + +This manifest was auto-generated from repository metadata. +Review before enabling in production. +"#, + name = c.name, + url = c.url, + owner = c.owner, + lang = lang, + stars = c.stars, + license = if c.has_license { "yes" } else { "unknown" }, + description = c.description, + ) + } +} + +/// Escape special characters for TOML basic string values. +fn escape_toml(s: &str) -> String { + s.replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('\n', "\\n") + .replace('\r', "\\r") + .replace('\t', "\\t") + .replace('\u{08}', "\\b") + .replace('\u{0C}', "\\f") +} + +/// Sanitize a string for use as a single path component. +/// Rejects empty names, "..", and names containing path separators or NUL. +fn sanitize_path_component(name: &str) -> Result { + let trimmed = name.trim().trim_matches('.'); + if trimmed.is_empty() { + bail!("Skill name is empty or only dots after sanitization"); + } + let sanitized: String = trimmed + .chars() + .map(|c| match c { + '/' | '\\' | '\0' => '_', + _ => c, + }) + .collect(); + if sanitized == ".." || sanitized.contains('/') || sanitized.contains('\\') { + bail!("Skill name '{}' is unsafe as a path component", name); + } + Ok(sanitized) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::skillforge::scout::{ScoutResult, ScoutSource}; + use std::fs; + + fn sample_candidate() -> ScoutResult { + ScoutResult { + name: "test-skill".into(), + url: "https://github.com/user/test-skill".into(), + description: "A test skill for unit tests".into(), + stars: 42, + language: Some("Rust".into()), + updated_at: Some(Utc::now()), + source: ScoutSource::GitHub, + owner: "user".into(), + has_license: true, + } + } + + #[test] + fn integrate_creates_files() { + let tmp = std::env::temp_dir().join("zeroclaw-test-integrate"); + let _ = fs::remove_dir_all(&tmp); + + let integrator = Integrator::new(tmp.to_string_lossy().into_owned()); + let c = sample_candidate(); + let path = integrator.integrate(&c).unwrap(); + + assert!(path.join("SKILL.toml").exists()); + assert!(path.join("SKILL.md").exists()); + + let toml = fs::read_to_string(path.join("SKILL.toml")).unwrap(); + assert!(toml.contains("name = \"test-skill\"")); + assert!(toml.contains("stars = 42")); + + let md = fs::read_to_string(path.join("SKILL.md")).unwrap(); + assert!(md.contains("# test-skill")); + assert!(md.contains("A test skill for unit tests")); + + let _ = fs::remove_dir_all(&tmp); + } + + #[test] + fn escape_toml_handles_quotes_and_control_chars() { + assert_eq!(escape_toml(r#"say "hello""#), r#"say \"hello\""#); + assert_eq!(escape_toml(r"back\slash"), r"back\\slash"); + assert_eq!(escape_toml("line\nbreak"), "line\\nbreak"); + assert_eq!(escape_toml("tab\there"), "tab\\there"); + assert_eq!(escape_toml("cr\rhere"), "cr\\rhere"); + } + + #[test] + fn sanitize_rejects_traversal() { + assert!(sanitize_path_component("..").is_err()); + assert!(sanitize_path_component("...").is_err()); + assert!(sanitize_path_component("").is_err()); + assert!(sanitize_path_component(" ").is_err()); + } + + #[test] + fn sanitize_replaces_separators() { + let s = sanitize_path_component("foo/bar\\baz\0qux").unwrap(); + assert!(!s.contains('/')); + assert!(!s.contains('\\')); + assert!(!s.contains('\0')); + assert_eq!(s, "foo_bar_baz_qux"); + } + + #[test] + fn sanitize_trims_dots() { + let s = sanitize_path_component(".hidden.").unwrap(); + assert_eq!(s, "hidden"); + } +} diff --git a/src/skillforge/mod.rs b/src/skillforge/mod.rs new file mode 100644 index 0000000..17c2336 --- /dev/null +++ b/src/skillforge/mod.rs @@ -0,0 +1,255 @@ +//! SkillForge — Skill auto-discovery, evaluation, and integration engine. +//! +//! Pipeline: Scout → Evaluate → Integrate +//! Discovers skills from external sources, scores them, and generates +//! ZeroClaw-compatible manifests for qualified candidates. + +pub mod evaluate; +pub mod integrate; +pub mod scout; + +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use tracing::{info, warn}; + +use self::evaluate::{EvalResult, Evaluator, Recommendation}; +use self::integrate::Integrator; +use self::scout::{GitHubScout, Scout, ScoutResult, ScoutSource}; + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +#[derive(Clone, Serialize, Deserialize)] +pub struct SkillForgeConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_auto_integrate")] + pub auto_integrate: bool, + #[serde(default = "default_sources")] + pub sources: Vec, + #[serde(default = "default_scan_interval")] + pub scan_interval_hours: u64, + #[serde(default = "default_min_score")] + pub min_score: f64, + /// Optional GitHub personal-access token for higher rate limits. + #[serde(default)] + pub github_token: Option, + /// Directory where integrated skills are written. + #[serde(default = "default_output_dir")] + pub output_dir: String, +} + +fn default_auto_integrate() -> bool { + true +} +fn default_sources() -> Vec { + vec!["github".into(), "clawhub".into()] +} +fn default_scan_interval() -> u64 { + 24 +} +fn default_min_score() -> f64 { + 0.7 +} +fn default_output_dir() -> String { + "./skills".into() +} + +impl Default for SkillForgeConfig { + fn default() -> Self { + Self { + enabled: false, + auto_integrate: default_auto_integrate(), + sources: default_sources(), + scan_interval_hours: default_scan_interval(), + min_score: default_min_score(), + github_token: None, + output_dir: default_output_dir(), + } + } +} + +impl std::fmt::Debug for SkillForgeConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SkillForgeConfig") + .field("enabled", &self.enabled) + .field("auto_integrate", &self.auto_integrate) + .field("sources", &self.sources) + .field("scan_interval_hours", &self.scan_interval_hours) + .field("min_score", &self.min_score) + .field("github_token", &self.github_token.as_ref().map(|_| "***")) + .field("output_dir", &self.output_dir) + .finish() + } +} + +// --------------------------------------------------------------------------- +// ForgeReport — summary of a single pipeline run +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForgeReport { + pub discovered: usize, + pub evaluated: usize, + pub auto_integrated: usize, + pub manual_review: usize, + pub skipped: usize, + pub results: Vec, +} + +// --------------------------------------------------------------------------- +// SkillForge +// --------------------------------------------------------------------------- + +pub struct SkillForge { + config: SkillForgeConfig, + evaluator: Evaluator, + integrator: Integrator, +} + +impl SkillForge { + pub fn new(config: SkillForgeConfig) -> Self { + let evaluator = Evaluator::new(config.min_score); + let integrator = Integrator::new(config.output_dir.clone()); + Self { + config, + evaluator, + integrator, + } + } + + /// Run the full pipeline: Scout → Evaluate → Integrate. + pub async fn forge(&self) -> Result { + if !self.config.enabled { + warn!("SkillForge is disabled — skipping"); + return Ok(ForgeReport { + discovered: 0, + evaluated: 0, + auto_integrated: 0, + manual_review: 0, + skipped: 0, + results: vec![], + }); + } + + // --- Scout ---------------------------------------------------------- + let mut candidates: Vec = Vec::new(); + + for src in &self.config.sources { + let source: ScoutSource = src.parse().unwrap(); // Infallible + match source { + ScoutSource::GitHub => { + let scout = GitHubScout::new(self.config.github_token.clone()); + match scout.discover().await { + Ok(mut found) => { + info!(count = found.len(), "GitHub scout returned candidates"); + candidates.append(&mut found); + } + Err(e) => { + warn!(error = %e, "GitHub scout failed, continuing with other sources"); + } + } + } + ScoutSource::ClawHub | ScoutSource::HuggingFace => { + info!( + source = src.as_str(), + "Source not yet implemented — skipping" + ); + } + } + } + + // Deduplicate by URL + scout::dedup(&mut candidates); + let discovered = candidates.len(); + info!(discovered, "Total unique candidates after dedup"); + + // --- Evaluate ------------------------------------------------------- + let results: Vec = candidates + .into_iter() + .map(|c| self.evaluator.evaluate(c)) + .collect(); + let evaluated = results.len(); + + // --- Integrate ------------------------------------------------------ + let mut auto_integrated = 0usize; + let mut manual_review = 0usize; + let mut skipped = 0usize; + + for res in &results { + match res.recommendation { + Recommendation::Auto => { + if self.config.auto_integrate { + match self.integrator.integrate(&res.candidate) { + Ok(_) => { + auto_integrated += 1; + } + Err(e) => { + warn!( + skill = res.candidate.name.as_str(), + error = %e, + "Integration failed for candidate, continuing" + ); + } + } + } else { + // Count as would-be auto but not actually integrated + manual_review += 1; + } + } + Recommendation::Manual => { + manual_review += 1; + } + Recommendation::Skip => { + skipped += 1; + } + } + } + + info!( + auto_integrated, + manual_review, skipped, "Forge pipeline complete" + ); + + Ok(ForgeReport { + discovered, + evaluated, + auto_integrated, + manual_review, + skipped, + results, + }) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn disabled_forge_returns_empty_report() { + let cfg = SkillForgeConfig { + enabled: false, + ..Default::default() + }; + let forge = SkillForge::new(cfg); + let report = forge.forge().await.unwrap(); + assert_eq!(report.discovered, 0); + assert_eq!(report.auto_integrated, 0); + } + + #[test] + fn default_config_values() { + let cfg = SkillForgeConfig::default(); + assert!(!cfg.enabled); + assert!(cfg.auto_integrate); + assert_eq!(cfg.scan_interval_hours, 24); + assert!((cfg.min_score - 0.7).abs() < f64::EPSILON); + assert_eq!(cfg.sources, vec!["github", "clawhub"]); + } +} diff --git a/src/skillforge/scout.rs b/src/skillforge/scout.rs new file mode 100644 index 0000000..1ad8af4 --- /dev/null +++ b/src/skillforge/scout.rs @@ -0,0 +1,339 @@ +//! Scout — skill discovery from external sources. + +use anyhow::Result; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use tracing::{debug, warn}; + +// --------------------------------------------------------------------------- +// ScoutSource +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ScoutSource { + GitHub, + ClawHub, + HuggingFace, +} + +impl std::str::FromStr for ScoutSource { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(match s.to_lowercase().as_str() { + "github" => Self::GitHub, + "clawhub" => Self::ClawHub, + "huggingface" | "hf" => Self::HuggingFace, + _ => { + warn!(source = s, "Unknown scout source, defaulting to GitHub"); + Self::GitHub + } + }) + } +} + +// --------------------------------------------------------------------------- +// ScoutResult +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoutResult { + pub name: String, + pub url: String, + pub description: String, + pub stars: u64, + pub language: Option, + pub updated_at: Option>, + pub source: ScoutSource, + /// Owner / org extracted from the URL or API response. + pub owner: String, + /// Whether the repo has a license file. + pub has_license: bool, +} + +// --------------------------------------------------------------------------- +// Scout trait +// --------------------------------------------------------------------------- + +#[async_trait] +pub trait Scout: Send + Sync { + /// Discover candidate skills from the source. + async fn discover(&self) -> Result>; +} + +// --------------------------------------------------------------------------- +// GitHubScout +// --------------------------------------------------------------------------- + +/// Searches GitHub for repos matching skill-related queries. +pub struct GitHubScout { + client: reqwest::Client, + queries: Vec, +} + +impl GitHubScout { + pub fn new(token: Option) -> Self { + use std::time::Duration; + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::ACCEPT, + "application/vnd.github+json".parse().expect("valid header"), + ); + headers.insert( + reqwest::header::USER_AGENT, + "ZeroClaw-SkillForge/0.1".parse().expect("valid header"), + ); + if let Some(ref t) = token { + if let Ok(val) = format!("Bearer {t}").parse() { + headers.insert(reqwest::header::AUTHORIZATION, val); + } + } + + let client = reqwest::Client::builder() + .default_headers(headers) + .timeout(Duration::from_secs(30)) + .build() + .expect("failed to build reqwest client"); + + Self { + client, + queries: vec!["zeroclaw skill".into(), "ai agent skill".into()], + } + } + + /// Parse the GitHub search/repositories JSON response. + fn parse_items(body: &serde_json::Value) -> Vec { + let items = match body.get("items").and_then(|v| v.as_array()) { + Some(arr) => arr, + None => return vec![], + }; + + items + .iter() + .filter_map(|item| { + let name = item.get("name")?.as_str()?.to_string(); + let url = item.get("html_url")?.as_str()?.to_string(); + let description = item + .get("description") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let stars = item + .get("stargazers_count") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + let language = item + .get("language") + .and_then(|v| v.as_str()) + .map(String::from); + let updated_at = item + .get("updated_at") + .and_then(|v| v.as_str()) + .and_then(|s| s.parse::>().ok()); + let owner = item + .get("owner") + .and_then(|o| o.get("login")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + let has_license = item.get("license").map(|v| !v.is_null()).unwrap_or(false); + + Some(ScoutResult { + name, + url, + description, + stars, + language, + updated_at, + source: ScoutSource::GitHub, + owner, + has_license, + }) + }) + .collect() + } +} + +#[async_trait] +impl Scout for GitHubScout { + async fn discover(&self) -> Result> { + let mut all: Vec = Vec::new(); + + for query in &self.queries { + let url = format!( + "https://api.github.com/search/repositories?q={}&sort=stars&order=desc&per_page=30", + urlencoding(query) + ); + debug!(query = query.as_str(), "Searching GitHub"); + + let resp = match self.client.get(&url).send().await { + Ok(r) => r, + Err(e) => { + warn!( + query = query.as_str(), + error = %e, + "GitHub API request failed, skipping query" + ); + continue; + } + }; + + if !resp.status().is_success() { + warn!( + status = %resp.status(), + query = query.as_str(), + "GitHub search returned non-200" + ); + continue; + } + + let body: serde_json::Value = match resp.json().await { + Ok(v) => v, + Err(e) => { + warn!( + query = query.as_str(), + error = %e, + "Failed to parse GitHub response, skipping query" + ); + continue; + } + }; + + let mut items = Self::parse_items(&body); + debug!(count = items.len(), query = query.as_str(), "Parsed items"); + all.append(&mut items); + } + + dedup(&mut all); + Ok(all) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Minimal percent-encoding for query strings (space → +). +fn urlencoding(s: &str) -> String { + s.replace(' ', "+").replace('&', "%26").replace('#', "%23") +} + +/// Deduplicate scout results by URL (keeps first occurrence). +pub fn dedup(results: &mut Vec) { + let mut seen = std::collections::HashSet::new(); + results.retain(|r| seen.insert(r.url.clone())); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn scout_source_from_str() { + assert_eq!( + "github".parse::().unwrap(), + ScoutSource::GitHub + ); + assert_eq!( + "GitHub".parse::().unwrap(), + ScoutSource::GitHub + ); + assert_eq!( + "clawhub".parse::().unwrap(), + ScoutSource::ClawHub + ); + assert_eq!( + "huggingface".parse::().unwrap(), + ScoutSource::HuggingFace + ); + assert_eq!( + "hf".parse::().unwrap(), + ScoutSource::HuggingFace + ); + // unknown falls back to GitHub + assert_eq!( + "unknown".parse::().unwrap(), + ScoutSource::GitHub + ); + } + + #[test] + fn dedup_removes_duplicates() { + let mut results = vec![ + ScoutResult { + name: "a".into(), + url: "https://github.com/x/a".into(), + description: String::new(), + stars: 10, + language: None, + updated_at: None, + source: ScoutSource::GitHub, + owner: "x".into(), + has_license: true, + }, + ScoutResult { + name: "a-dup".into(), + url: "https://github.com/x/a".into(), + description: String::new(), + stars: 10, + language: None, + updated_at: None, + source: ScoutSource::GitHub, + owner: "x".into(), + has_license: true, + }, + ScoutResult { + name: "b".into(), + url: "https://github.com/x/b".into(), + description: String::new(), + stars: 5, + language: None, + updated_at: None, + source: ScoutSource::GitHub, + owner: "x".into(), + has_license: false, + }, + ]; + dedup(&mut results); + assert_eq!(results.len(), 2); + assert_eq!(results[0].name, "a"); + assert_eq!(results[1].name, "b"); + } + + #[test] + fn parse_github_items() { + let json = serde_json::json!({ + "total_count": 1, + "items": [ + { + "name": "cool-skill", + "html_url": "https://github.com/user/cool-skill", + "description": "A cool skill", + "stargazers_count": 42, + "language": "Rust", + "updated_at": "2026-01-15T10:00:00Z", + "owner": { "login": "user" }, + "license": { "spdx_id": "MIT" } + } + ] + }); + let items = GitHubScout::parse_items(&json); + assert_eq!(items.len(), 1); + assert_eq!(items[0].name, "cool-skill"); + assert_eq!(items[0].stars, 42); + assert!(items[0].has_license); + assert_eq!(items[0].owner, "user"); + } + + #[test] + fn urlencoding_works() { + assert_eq!(urlencoding("hello world"), "hello+world"); + assert_eq!(urlencoding("a&b#c"), "a%26b%23c"); + } +} diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 6bf43f0..4db6cbb 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -453,9 +453,9 @@ fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> { /// Handle the `skills` CLI command #[allow(clippy::too_many_lines)] -pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> { +pub fn handle_command(command: crate::SkillCommands, workspace_dir: &Path) -> Result<()> { match command { - super::SkillCommands::List => { + crate::SkillCommands::List => { let skills = load_skills(workspace_dir); if skills.is_empty() { println!("No skills installed."); @@ -493,13 +493,13 @@ pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Re println!(); Ok(()) } - super::SkillCommands::Install { source } => { + crate::SkillCommands::Install { source } => { println!("Installing skill from: {source}"); let skills_path = skills_dir(workspace_dir); std::fs::create_dir_all(&skills_path)?; - if source.starts_with("http") || source.contains("github.com") { + if source.starts_with("https://") || source.starts_with("http://") { // Git clone let output = std::process::Command::new("git") .args(["clone", "--depth", "1", &source]) @@ -584,8 +584,24 @@ pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Re Ok(()) } - super::SkillCommands::Remove { name } => { + crate::SkillCommands::Remove { name } => { + // Reject path traversal attempts + if name.contains("..") || name.contains('/') || name.contains('\\') { + anyhow::bail!("Invalid skill name: {name}"); + } + let skill_path = skills_dir(workspace_dir).join(&name); + + // Verify the resolved path is actually inside the skills directory + let canonical_skills = skills_dir(workspace_dir) + .canonicalize() + .unwrap_or_else(|_| skills_dir(workspace_dir)); + if let Ok(canonical_skill) = skill_path.canonicalize() { + if !canonical_skill.starts_with(&canonical_skills) { + anyhow::bail!("Skill path escapes skills directory: {name}"); + } + } + if !skill_path.exists() { anyhow::bail!("Skill not found: {name}"); } diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 2dbec77..4e3d59e 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -1,24 +1,100 @@ -//! Browser automation tool using Vercel's agent-browser CLI +//! Browser automation tool with pluggable backends. //! -//! This tool provides AI-optimized web browsing capabilities via the agent-browser CLI. -//! It supports semantic element selection, accessibility snapshots, and JSON output -//! for efficient LLM integration. +//! By default this uses Vercel's `agent-browser` CLI for automation. +//! Optionally, a Rust-native backend can be enabled at build time via +//! `--features browser-native` and selected through config. +//! Computer-use (OS-level) actions are supported via an optional sidecar endpoint. use super::traits::{Tool, ToolResult}; use crate::security::SecurityPolicy; +use anyhow::Context; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +use std::net::ToSocketAddrs; use std::process::Stdio; use std::sync::Arc; +use std::time::Duration; use tokio::process::Command; use tracing::debug; -/// Browser automation tool using agent-browser CLI +/// Computer-use sidecar settings. +#[derive(Debug, Clone)] +pub struct ComputerUseConfig { + pub endpoint: String, + pub api_key: Option, + pub timeout_ms: u64, + pub allow_remote_endpoint: bool, + pub window_allowlist: Vec, + pub max_coordinate_x: Option, + pub max_coordinate_y: Option, +} + +impl Default for ComputerUseConfig { + fn default() -> Self { + Self { + endpoint: "http://127.0.0.1:8787/v1/actions".into(), + api_key: None, + timeout_ms: 15_000, + allow_remote_endpoint: false, + window_allowlist: Vec::new(), + max_coordinate_x: None, + max_coordinate_y: None, + } + } +} + +/// Browser automation tool using pluggable backends. pub struct BrowserTool { security: Arc, allowed_domains: Vec, session_name: Option, + backend: String, + native_headless: bool, + native_webdriver_url: String, + native_chrome_path: Option, + computer_use: ComputerUseConfig, + #[cfg(feature = "browser-native")] + native_state: tokio::sync::Mutex, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BrowserBackendKind { + AgentBrowser, + RustNative, + ComputerUse, + Auto, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ResolvedBackend { + AgentBrowser, + RustNative, + ComputerUse, +} + +impl BrowserBackendKind { + fn parse(raw: &str) -> anyhow::Result { + let key = raw.trim().to_ascii_lowercase().replace('-', "_"); + match key.as_str() { + "agent_browser" | "agentbrowser" => Ok(Self::AgentBrowser), + "rust_native" | "native" => Ok(Self::RustNative), + "computer_use" | "computeruse" => Ok(Self::ComputerUse), + "auto" => Ok(Self::Auto), + _ => anyhow::bail!( + "Unsupported browser backend '{raw}'. Use 'agent_browser', 'rust_native', 'computer_use', or 'auto'" + ), + } + } + + fn as_str(self) -> &'static str { + match self { + Self::AgentBrowser => "agent_browser", + Self::RustNative => "rust_native", + Self::ComputerUse => "computer_use", + Self::Auto => "auto", + } + } } /// Response from agent-browser --json commands @@ -29,6 +105,17 @@ struct AgentBrowserResponse { error: Option, } +/// Response format from computer-use sidecar. +#[derive(Debug, Deserialize)] +struct ComputerUseResponse { + #[serde(default)] + success: Option, + #[serde(default)] + data: Option, + #[serde(default)] + error: Option, +} + /// Supported browser actions #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -101,16 +188,46 @@ impl BrowserTool { security: Arc, allowed_domains: Vec, session_name: Option, + ) -> Self { + Self::new_with_backend( + security, + allowed_domains, + session_name, + "agent_browser".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig::default(), + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn new_with_backend( + security: Arc, + allowed_domains: Vec, + session_name: Option, + backend: String, + native_headless: bool, + native_webdriver_url: String, + native_chrome_path: Option, + computer_use: ComputerUseConfig, ) -> Self { Self { security, allowed_domains: normalize_domains(allowed_domains), session_name, + backend, + native_headless, + native_webdriver_url, + native_chrome_path, + computer_use, + #[cfg(feature = "browser-native")] + native_state: tokio::sync::Mutex::new(native_backend::NativeBrowserState::default()), } } /// Check if agent-browser CLI is available - pub async fn is_available() -> bool { + pub async fn is_agent_browser_available() -> bool { Command::new("agent-browser") .arg("--version") .stdout(Stdio::null()) @@ -121,6 +238,153 @@ impl BrowserTool { .unwrap_or(false) } + /// Backward-compatible alias. + pub async fn is_available() -> bool { + Self::is_agent_browser_available().await + } + + fn configured_backend(&self) -> anyhow::Result { + BrowserBackendKind::parse(&self.backend) + } + + fn rust_native_compiled() -> bool { + cfg!(feature = "browser-native") + } + + fn rust_native_available(&self) -> bool { + #[cfg(feature = "browser-native")] + { + native_backend::NativeBrowserState::is_available( + self.native_headless, + &self.native_webdriver_url, + self.native_chrome_path.as_deref(), + ) + } + #[cfg(not(feature = "browser-native"))] + { + false + } + } + + fn computer_use_endpoint_url(&self) -> anyhow::Result { + if self.computer_use.timeout_ms == 0 { + anyhow::bail!("browser.computer_use.timeout_ms must be > 0"); + } + + let endpoint = self.computer_use.endpoint.trim(); + if endpoint.is_empty() { + anyhow::bail!("browser.computer_use.endpoint cannot be empty"); + } + + let parsed = reqwest::Url::parse(endpoint).map_err(|_| { + anyhow::anyhow!( + "Invalid browser.computer_use.endpoint: '{endpoint}'. Expected http(s) URL" + ) + })?; + + let scheme = parsed.scheme(); + if scheme != "http" && scheme != "https" { + anyhow::bail!("browser.computer_use.endpoint must use http:// or https://"); + } + + let host = parsed + .host_str() + .ok_or_else(|| anyhow::anyhow!("browser.computer_use.endpoint must include host"))?; + + let host_is_private = is_private_host(host); + if !self.computer_use.allow_remote_endpoint && !host_is_private { + anyhow::bail!( + "browser.computer_use.endpoint host '{host}' is public. Set browser.computer_use.allow_remote_endpoint=true to allow it" + ); + } + + if self.computer_use.allow_remote_endpoint && !host_is_private && scheme != "https" { + anyhow::bail!( + "browser.computer_use.endpoint must use https:// when allow_remote_endpoint=true and host is public" + ); + } + + Ok(parsed) + } + + fn computer_use_available(&self) -> anyhow::Result { + let endpoint = self.computer_use_endpoint_url()?; + Ok(endpoint_reachable(&endpoint, Duration::from_millis(500))) + } + + async fn resolve_backend(&self) -> anyhow::Result { + let configured = self.configured_backend()?; + + match configured { + BrowserBackendKind::AgentBrowser => { + if Self::is_agent_browser_available().await { + Ok(ResolvedBackend::AgentBrowser) + } else { + anyhow::bail!( + "browser.backend='{}' but agent-browser CLI is unavailable. Install with: npm install -g agent-browser", + configured.as_str() + ) + } + } + BrowserBackendKind::RustNative => { + if !Self::rust_native_compiled() { + anyhow::bail!( + "browser.backend='rust_native' requires build feature 'browser-native'" + ); + } + if !self.rust_native_available() { + anyhow::bail!( + "Rust-native browser backend is enabled but WebDriver endpoint is unreachable. Set browser.native_webdriver_url and start a compatible driver" + ); + } + Ok(ResolvedBackend::RustNative) + } + BrowserBackendKind::ComputerUse => { + if !self.computer_use_available()? { + anyhow::bail!( + "browser.backend='computer_use' but sidecar endpoint is unreachable. Check browser.computer_use.endpoint and sidecar status" + ); + } + Ok(ResolvedBackend::ComputerUse) + } + BrowserBackendKind::Auto => { + if Self::rust_native_compiled() && self.rust_native_available() { + return Ok(ResolvedBackend::RustNative); + } + if Self::is_agent_browser_available().await { + return Ok(ResolvedBackend::AgentBrowser); + } + + let computer_use_err = match self.computer_use_available() { + Ok(true) => return Ok(ResolvedBackend::ComputerUse), + Ok(false) => None, + Err(err) => Some(err.to_string()), + }; + + if Self::rust_native_compiled() { + if let Some(err) = computer_use_err { + anyhow::bail!( + "browser.backend='auto' found no usable backend (agent-browser missing, rust-native unavailable, computer-use invalid: {err})" + ); + } + anyhow::bail!( + "browser.backend='auto' found no usable backend (agent-browser missing, rust-native unavailable, computer-use sidecar unreachable)" + ) + } + + if let Some(err) = computer_use_err { + anyhow::bail!( + "browser.backend='auto' needs agent-browser CLI, browser-native, or valid computer-use sidecar (error: {err})" + ); + } + + anyhow::bail!( + "browser.backend='auto' needs agent-browser CLI, browser-native, or computer-use sidecar" + ) + } + } + } + /// Validate URL against allowlist fn validate_url(&self, url: &str) -> anyhow::Result<()> { let url = url.trim(); @@ -129,9 +393,10 @@ impl BrowserTool { anyhow::bail!("URL cannot be empty"); } - // Allow file:// URLs for local testing + // Block file:// URLs — browser file access bypasses all SSRF and + // domain-allowlist controls and can exfiltrate arbitrary local files. if url.starts_with("file://") { - return Ok(()); + anyhow::bail!("file:// URLs are not allowed in browser automation"); } if !url.starts_with("https://") && !url.starts_with("http://") { @@ -206,9 +471,12 @@ impl BrowserTool { } } - /// Execute a browser action + /// Execute a browser action via agent-browser CLI #[allow(clippy::too_many_lines)] - async fn execute_action(&self, action: BrowserAction) -> anyhow::Result { + async fn execute_agent_browser_action( + &self, + action: BrowserAction, + ) -> anyhow::Result { match action { BrowserAction::Open { url } => { self.validate_url(&url)?; @@ -343,6 +611,227 @@ impl BrowserTool { } } + #[allow(clippy::unused_async)] + async fn execute_rust_native_action( + &self, + action: BrowserAction, + ) -> anyhow::Result { + #[cfg(feature = "browser-native")] + { + let mut state = self.native_state.lock().await; + + let output = state + .execute_action( + action, + self.native_headless, + &self.native_webdriver_url, + self.native_chrome_path.as_deref(), + ) + .await?; + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&output).unwrap_or_default(), + error: None, + }) + } + + #[cfg(not(feature = "browser-native"))] + { + let _ = action; + anyhow::bail!( + "Rust-native browser backend is not compiled. Rebuild with --features browser-native" + ) + } + } + + fn validate_coordinate(&self, key: &str, value: i64, max: Option) -> anyhow::Result<()> { + if value < 0 { + anyhow::bail!("'{key}' must be >= 0") + } + if let Some(limit) = max { + if limit < 0 { + anyhow::bail!("Configured coordinate limit for '{key}' must be >= 0") + } + if value > limit { + anyhow::bail!("'{key}'={value} exceeds configured limit {limit}") + } + } + Ok(()) + } + + fn read_required_i64( + &self, + params: &serde_json::Map, + key: &str, + ) -> anyhow::Result { + params + .get(key) + .and_then(Value::as_i64) + .ok_or_else(|| anyhow::anyhow!("Missing or invalid '{key}' parameter")) + } + + fn validate_computer_use_action( + &self, + action: &str, + params: &serde_json::Map, + ) -> anyhow::Result<()> { + match action { + "open" => { + let url = params + .get("url") + .and_then(Value::as_str) + .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; + self.validate_url(url)?; + } + "mouse_move" | "mouse_click" => { + let x = self.read_required_i64(params, "x")?; + let y = self.read_required_i64(params, "y")?; + self.validate_coordinate("x", x, self.computer_use.max_coordinate_x)?; + self.validate_coordinate("y", y, self.computer_use.max_coordinate_y)?; + } + "mouse_drag" => { + let from_x = self.read_required_i64(params, "from_x")?; + let from_y = self.read_required_i64(params, "from_y")?; + let to_x = self.read_required_i64(params, "to_x")?; + let to_y = self.read_required_i64(params, "to_y")?; + self.validate_coordinate("from_x", from_x, self.computer_use.max_coordinate_x)?; + self.validate_coordinate("to_x", to_x, self.computer_use.max_coordinate_x)?; + self.validate_coordinate("from_y", from_y, self.computer_use.max_coordinate_y)?; + self.validate_coordinate("to_y", to_y, self.computer_use.max_coordinate_y)?; + } + _ => {} + } + Ok(()) + } + + async fn execute_computer_use_action( + &self, + action: &str, + args: &Value, + ) -> anyhow::Result { + let endpoint = self.computer_use_endpoint_url()?; + + let mut params = args + .as_object() + .cloned() + .ok_or_else(|| anyhow::anyhow!("browser args must be a JSON object"))?; + params.remove("action"); + + self.validate_computer_use_action(action, ¶ms)?; + + let payload = json!({ + "action": action, + "params": params, + "policy": { + "allowed_domains": self.allowed_domains, + "window_allowlist": self.computer_use.window_allowlist, + "max_coordinate_x": self.computer_use.max_coordinate_x, + "max_coordinate_y": self.computer_use.max_coordinate_y, + }, + "metadata": { + "session_name": self.session_name, + "source": "zeroclaw.browser", + "version": env!("CARGO_PKG_VERSION"), + } + }); + + let client = reqwest::Client::new(); + let mut request = client + .post(endpoint) + .timeout(Duration::from_millis(self.computer_use.timeout_ms)) + .json(&payload); + + if let Some(api_key) = self.computer_use.api_key.as_deref() { + let token = api_key.trim(); + if !token.is_empty() { + request = request.bearer_auth(token); + } + } + + let response = request.send().await.with_context(|| { + format!( + "Failed to call computer-use sidecar at {}", + self.computer_use.endpoint + ) + })?; + + let status = response.status(); + let body = response + .text() + .await + .context("Failed to read computer-use sidecar response body")?; + + if let Ok(parsed) = serde_json::from_str::(&body) { + if status.is_success() && parsed.success.unwrap_or(true) { + let output = parsed + .data + .map(|data| serde_json::to_string_pretty(&data).unwrap_or_default()) + .unwrap_or_else(|| { + serde_json::to_string_pretty(&json!({ + "backend": "computer_use", + "action": action, + "ok": true, + })) + .unwrap_or_default() + }); + + return Ok(ToolResult { + success: true, + output, + error: None, + }); + } + + let error = parsed.error.or_else(|| { + if status.is_success() && parsed.success == Some(false) { + Some("computer-use sidecar returned success=false".to_string()) + } else { + Some(format!( + "computer-use sidecar request failed with status {status}" + )) + } + }); + + return Ok(ToolResult { + success: false, + output: String::new(), + error, + }); + } + + if status.is_success() { + return Ok(ToolResult { + success: true, + output: body, + error: None, + }); + } + + Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "computer-use sidecar request failed with status {status}: {}", + body.trim() + )), + }) + } + + async fn execute_action( + &self, + action: BrowserAction, + backend: ResolvedBackend, + ) -> anyhow::Result { + match backend { + ResolvedBackend::AgentBrowser => self.execute_agent_browser_action(action).await, + ResolvedBackend::RustNative => self.execute_rust_native_action(action).await, + ResolvedBackend::ComputerUse => anyhow::bail!( + "Internal error: computer_use backend must be handled before BrowserAction parsing" + ), + } + } + #[allow(clippy::unnecessary_wraps, clippy::unused_self)] fn to_result(&self, resp: AgentBrowserResponse) -> anyhow::Result { if resp.success { @@ -365,7 +854,6 @@ impl BrowserTool { } } -#[allow(clippy::too_many_lines)] #[async_trait] impl Tool for BrowserTool { fn name(&self) -> &str { @@ -373,10 +861,12 @@ impl Tool for BrowserTool { } fn description(&self) -> &str { - "Web browser automation using agent-browser. Supports navigation, clicking, \ - filling forms, taking screenshots, and getting accessibility snapshots with refs. \ - Use 'snapshot' to get interactive elements with refs (@e1, @e2), then use refs \ - for precise element interaction. Allowed domains only." + concat!( + "Web/browser automation with pluggable backends (agent-browser, rust-native, computer_use). ", + "Supports DOM actions plus optional OS-level actions (mouse_move, mouse_click, mouse_drag, ", + "key_type, key_press, screen_capture) through a computer-use sidecar. Use 'snapshot' to map ", + "interactive elements to refs (@e1, @e2). Enforces browser.allowed_domains for open actions." + ) } fn parameters_schema(&self) -> Value { @@ -387,8 +877,10 @@ impl Tool for BrowserTool { "type": "string", "enum": ["open", "snapshot", "click", "fill", "type", "get_text", "get_title", "get_url", "screenshot", "wait", "press", - "hover", "scroll", "is_visible", "close", "find"], - "description": "Browser action to perform" + "hover", "scroll", "is_visible", "close", "find", + "mouse_move", "mouse_click", "mouse_drag", "key_type", + "key_press", "screen_capture"], + "description": "Browser action to perform (OS-level actions require backend=computer_use)" }, "url": { "type": "string", @@ -410,6 +902,35 @@ impl Tool for BrowserTool { "type": "string", "description": "Key to press (Enter, Tab, Escape, etc.)" }, + "x": { + "type": "integer", + "description": "Screen X coordinate (computer_use: mouse_move/mouse_click)" + }, + "y": { + "type": "integer", + "description": "Screen Y coordinate (computer_use: mouse_move/mouse_click)" + }, + "from_x": { + "type": "integer", + "description": "Drag source X coordinate (computer_use: mouse_drag)" + }, + "from_y": { + "type": "integer", + "description": "Drag source Y coordinate (computer_use: mouse_drag)" + }, + "to_x": { + "type": "integer", + "description": "Drag target X coordinate (computer_use: mouse_drag)" + }, + "to_y": { + "type": "integer", + "description": "Drag target Y coordinate (computer_use: mouse_drag)" + }, + "button": { + "type": "string", + "enum": ["left", "right", "middle"], + "description": "Mouse button for computer_use mouse_click" + }, "direction": { "type": "string", "enum": ["up", "down", "left", "right"], @@ -480,17 +1001,16 @@ impl Tool for BrowserTool { }); } - // Check if agent-browser is available - if !Self::is_available().await { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some( - "agent-browser CLI not found. Install with: npm install -g agent-browser" - .into(), - ), - }); - } + let backend = match self.resolve_backend().await { + Ok(selected) => selected, + Err(error) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error.to_string()), + }); + } + }; // Parse action from args let action_str = args @@ -498,168 +1018,921 @@ impl Tool for BrowserTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'action' parameter"))?; - let action = match action_str { - "open" => { - let url = args - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; - BrowserAction::Open { url: url.into() } - } - "snapshot" => BrowserAction::Snapshot { - interactive_only: args - .get("interactive_only") - .and_then(serde_json::Value::as_bool) - .unwrap_or(true), // Default to interactive for AI - compact: args - .get("compact") - .and_then(serde_json::Value::as_bool) - .unwrap_or(true), - depth: args - .get("depth") - .and_then(serde_json::Value::as_u64) - .map(|d| u32::try_from(d).unwrap_or(u32::MAX)), - }, - "click" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?; - BrowserAction::Click { - selector: selector.into(), - } - } - "fill" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?; - let value = args - .get("value") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?; - BrowserAction::Fill { - selector: selector.into(), - value: value.into(), - } - } - "type" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?; - let text = args - .get("text") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?; - BrowserAction::Type { - selector: selector.into(), - text: text.into(), - } - } - "get_text" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?; - BrowserAction::GetText { - selector: selector.into(), - } - } - "get_title" => BrowserAction::GetTitle, - "get_url" => BrowserAction::GetUrl, - "screenshot" => BrowserAction::Screenshot { - path: args.get("path").and_then(|v| v.as_str()).map(String::from), - full_page: args - .get("full_page") - .and_then(serde_json::Value::as_bool) - .unwrap_or(false), - }, - "wait" => BrowserAction::Wait { - selector: args - .get("selector") - .and_then(|v| v.as_str()) - .map(String::from), - ms: args.get("ms").and_then(serde_json::Value::as_u64), - text: args.get("text").and_then(|v| v.as_str()).map(String::from), - }, - "press" => { - let key = args - .get("key") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?; - BrowserAction::Press { key: key.into() } - } - "hover" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?; - BrowserAction::Hover { - selector: selector.into(), - } - } - "scroll" => { - let direction = args - .get("direction") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?; - BrowserAction::Scroll { - direction: direction.into(), - pixels: args - .get("pixels") - .and_then(serde_json::Value::as_u64) - .map(|p| u32::try_from(p).unwrap_or(u32::MAX)), - } - } - "is_visible" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?; - BrowserAction::IsVisible { - selector: selector.into(), - } - } - "close" => BrowserAction::Close, - "find" => { - let by = args - .get("by") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?; - let value = args - .get("value") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?; - let action = args - .get("find_action") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?; - BrowserAction::Find { - by: by.into(), - value: value.into(), - action: action.into(), - fill_value: args - .get("fill_value") - .and_then(|v| v.as_str()) - .map(String::from), - } - } - _ => { + if !is_supported_browser_action(action_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown action: {action_str}")), + }); + } + + if backend == ResolvedBackend::ComputerUse { + return self.execute_computer_use_action(action_str, &args).await; + } + + if is_computer_use_only_action(action_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(unavailable_action_for_backend_error(action_str, backend)), + }); + } + + let action = match parse_browser_action(action_str, &args) { + Ok(a) => a, + Err(e) => { return Ok(ToolResult { success: false, output: String::new(), - error: Some(format!("Unknown action: {action_str}")), + error: Some(e.to_string()), }); } }; - self.execute_action(action).await + self.execute_action(action, backend).await + } +} + +#[cfg(feature = "browser-native")] +mod native_backend { + use super::BrowserAction; + use anyhow::{Context, Result}; + use base64::Engine; + use fantoccini::actions::{InputSource, MouseActions, PointerAction}; + use fantoccini::key::Key; + use fantoccini::{Client, ClientBuilder, Locator}; + use serde_json::{json, Map, Value}; + use std::net::{TcpStream, ToSocketAddrs}; + use std::time::Duration; + + #[derive(Default)] + pub struct NativeBrowserState { + client: Option, + } + + impl NativeBrowserState { + pub fn is_available( + _headless: bool, + webdriver_url: &str, + _chrome_path: Option<&str>, + ) -> bool { + webdriver_endpoint_reachable(webdriver_url, Duration::from_millis(500)) + } + + #[allow(clippy::too_many_lines)] + pub async fn execute_action( + &mut self, + action: BrowserAction, + headless: bool, + webdriver_url: &str, + chrome_path: Option<&str>, + ) -> Result { + match action { + BrowserAction::Open { url } => { + self.ensure_session(headless, webdriver_url, chrome_path) + .await?; + let client = self.active_client()?; + client + .goto(&url) + .await + .with_context(|| format!("Failed to open URL: {url}"))?; + let current_url = client + .current_url() + .await + .context("Failed to read current URL after navigation")?; + + Ok(json!({ + "backend": "rust_native", + "action": "open", + "url": current_url.as_str(), + })) + } + BrowserAction::Snapshot { + interactive_only, + compact, + depth, + } => { + let client = self.active_client()?; + let snapshot = client + .execute( + &snapshot_script(interactive_only, compact, depth.map(i64::from)), + vec![], + ) + .await + .context("Failed to evaluate snapshot script")?; + + Ok(json!({ + "backend": "rust_native", + "action": "snapshot", + "data": snapshot, + })) + } + BrowserAction::Click { selector } => { + let client = self.active_client()?; + find_element(client, &selector).await?.click().await?; + + Ok(json!({ + "backend": "rust_native", + "action": "click", + "selector": selector, + })) + } + BrowserAction::Fill { selector, value } => { + let client = self.active_client()?; + let element = find_element(client, &selector).await?; + let _ = element.clear().await; + element.send_keys(&value).await?; + + Ok(json!({ + "backend": "rust_native", + "action": "fill", + "selector": selector, + })) + } + BrowserAction::Type { selector, text } => { + let client = self.active_client()?; + find_element(client, &selector) + .await? + .send_keys(&text) + .await?; + + Ok(json!({ + "backend": "rust_native", + "action": "type", + "selector": selector, + "typed": text.len(), + })) + } + BrowserAction::GetText { selector } => { + let client = self.active_client()?; + let text = find_element(client, &selector).await?.text().await?; + + Ok(json!({ + "backend": "rust_native", + "action": "get_text", + "selector": selector, + "text": text, + })) + } + BrowserAction::GetTitle => { + let client = self.active_client()?; + let title = client.title().await.context("Failed to read page title")?; + + Ok(json!({ + "backend": "rust_native", + "action": "get_title", + "title": title, + })) + } + BrowserAction::GetUrl => { + let client = self.active_client()?; + let url = client + .current_url() + .await + .context("Failed to read current URL")?; + + Ok(json!({ + "backend": "rust_native", + "action": "get_url", + "url": url.as_str(), + })) + } + BrowserAction::Screenshot { path, full_page } => { + let client = self.active_client()?; + let png = client + .screenshot() + .await + .context("Failed to capture screenshot")?; + let mut payload = json!({ + "backend": "rust_native", + "action": "screenshot", + "full_page": full_page, + "bytes": png.len(), + }); + + if let Some(path_str) = path { + std::fs::write(&path_str, &png) + .with_context(|| format!("Failed to write screenshot to {path_str}"))?; + payload["path"] = Value::String(path_str); + } else { + payload["png_base64"] = + Value::String(base64::engine::general_purpose::STANDARD.encode(&png)); + } + + Ok(payload) + } + BrowserAction::Wait { selector, ms, text } => { + let client = self.active_client()?; + if let Some(sel) = selector.as_ref() { + wait_for_selector(client, sel).await?; + Ok(json!({ + "backend": "rust_native", + "action": "wait", + "selector": sel, + })) + } else if let Some(duration_ms) = ms { + tokio::time::sleep(Duration::from_millis(duration_ms)).await; + Ok(json!({ + "backend": "rust_native", + "action": "wait", + "ms": duration_ms, + })) + } else if let Some(needle) = text.as_ref() { + let xpath = xpath_contains_text(needle); + client + .wait() + .for_element(Locator::XPath(&xpath)) + .await + .with_context(|| { + format!("Timed out waiting for text to appear: {needle}") + })?; + Ok(json!({ + "backend": "rust_native", + "action": "wait", + "text": needle, + })) + } else { + tokio::time::sleep(Duration::from_millis(250)).await; + Ok(json!({ + "backend": "rust_native", + "action": "wait", + "ms": 250, + })) + } + } + BrowserAction::Press { key } => { + let client = self.active_client()?; + let key_input = webdriver_key(&key); + match client.active_element().await { + Ok(element) => { + element.send_keys(&key_input).await?; + } + Err(_) => { + find_element(client, "body") + .await? + .send_keys(&key_input) + .await?; + } + } + + Ok(json!({ + "backend": "rust_native", + "action": "press", + "key": key, + })) + } + BrowserAction::Hover { selector } => { + let client = self.active_client()?; + let element = find_element(client, &selector).await?; + hover_element(client, &element).await?; + + Ok(json!({ + "backend": "rust_native", + "action": "hover", + "selector": selector, + })) + } + BrowserAction::Scroll { direction, pixels } => { + let client = self.active_client()?; + let amount = i64::from(pixels.unwrap_or(600)); + let (dx, dy) = match direction.as_str() { + "up" => (0, -amount), + "down" => (0, amount), + "left" => (-amount, 0), + "right" => (amount, 0), + _ => anyhow::bail!( + "Unsupported scroll direction '{direction}'. Use up/down/left/right" + ), + }; + + let position = client + .execute( + "window.scrollBy(arguments[0], arguments[1]); return { x: window.scrollX, y: window.scrollY };", + vec![json!(dx), json!(dy)], + ) + .await + .context("Failed to execute scroll script")?; + + Ok(json!({ + "backend": "rust_native", + "action": "scroll", + "position": position, + })) + } + BrowserAction::IsVisible { selector } => { + let client = self.active_client()?; + let visible = find_element(client, &selector) + .await? + .is_displayed() + .await?; + + Ok(json!({ + "backend": "rust_native", + "action": "is_visible", + "selector": selector, + "visible": visible, + })) + } + BrowserAction::Close => { + if let Some(client) = self.client.take() { + let _ = client.close().await; + } + + Ok(json!({ + "backend": "rust_native", + "action": "close", + "closed": true, + })) + } + BrowserAction::Find { + by, + value, + action, + fill_value, + } => { + let client = self.active_client()?; + let selector = selector_for_find(&by, &value); + let element = find_element(client, &selector).await?; + + let payload = match action.as_str() { + "click" => { + element.click().await?; + json!({"result": "clicked"}) + } + "fill" => { + let fill = fill_value.ok_or_else(|| { + anyhow::anyhow!("find_action='fill' requires fill_value") + })?; + let _ = element.clear().await; + element.send_keys(&fill).await?; + json!({"result": "filled", "typed": fill.len()}) + } + "text" => { + let text = element.text().await?; + json!({"result": "text", "text": text}) + } + "hover" => { + hover_element(client, &element).await?; + json!({"result": "hovered"}) + } + "check" => { + let checked_before = element_checked(&element).await?; + if !checked_before { + element.click().await?; + } + let checked_after = element_checked(&element).await?; + json!({ + "result": "checked", + "checked_before": checked_before, + "checked_after": checked_after, + }) + } + _ => anyhow::bail!( + "Unsupported find_action '{action}'. Use click/fill/text/hover/check" + ), + }; + + Ok(json!({ + "backend": "rust_native", + "action": "find", + "by": by, + "value": value, + "selector": selector, + "data": payload, + })) + } + } + } + + async fn ensure_session( + &mut self, + headless: bool, + webdriver_url: &str, + chrome_path: Option<&str>, + ) -> Result<()> { + if self.client.is_some() { + return Ok(()); + } + + let mut capabilities: Map = Map::new(); + let mut chrome_options: Map = Map::new(); + let mut args: Vec = Vec::new(); + + if headless { + args.push(Value::String("--headless=new".to_string())); + args.push(Value::String("--disable-gpu".to_string())); + } + + if !args.is_empty() { + chrome_options.insert("args".to_string(), Value::Array(args)); + } + + if let Some(path) = chrome_path { + let trimmed = path.trim(); + if !trimmed.is_empty() { + chrome_options.insert("binary".to_string(), Value::String(trimmed.to_string())); + } + } + + if !chrome_options.is_empty() { + capabilities.insert( + "goog:chromeOptions".to_string(), + Value::Object(chrome_options), + ); + } + + let mut builder = + ClientBuilder::rustls().context("Failed to initialize rustls connector")?; + if !capabilities.is_empty() { + builder.capabilities(capabilities); + } + + let client = builder + .connect(webdriver_url) + .await + .with_context(|| { + format!( + "Failed to connect to WebDriver at {webdriver_url}. Start chromedriver/geckodriver first" + ) + })?; + + self.client = Some(client); + Ok(()) + } + + fn active_client(&self) -> Result<&Client> { + self.client.as_ref().ok_or_else(|| { + anyhow::anyhow!("No active native browser session. Run browser action='open' first") + }) + } + } + + fn webdriver_endpoint_reachable(webdriver_url: &str, timeout: Duration) -> bool { + let parsed = match reqwest::Url::parse(webdriver_url) { + Ok(url) => url, + Err(_) => return false, + }; + + if parsed.scheme() != "http" && parsed.scheme() != "https" { + return false; + } + + let host = match parsed.host_str() { + Some(h) if !h.is_empty() => h, + _ => return false, + }; + + let port = parsed.port_or_known_default().unwrap_or(4444); + let mut addrs = match (host, port).to_socket_addrs() { + Ok(iter) => iter, + Err(_) => return false, + }; + + let addr = match addrs.next() { + Some(a) => a, + None => return false, + }; + + TcpStream::connect_timeout(&addr, timeout).is_ok() + } + + fn selector_for_find(by: &str, value: &str) -> String { + let escaped = css_attr_escape(value); + match by { + "role" => format!(r#"[role=\"{escaped}\"]"#), + "label" => format!("label={value}"), + "placeholder" => format!(r#"[placeholder=\"{escaped}\"]"#), + "testid" => format!(r#"[data-testid=\"{escaped}\"]"#), + _ => format!("text={value}"), + } + } + + async fn wait_for_selector(client: &Client, selector: &str) -> Result<()> { + match parse_selector(selector) { + SelectorKind::Css(css) => { + client + .wait() + .for_element(Locator::Css(&css)) + .await + .with_context(|| format!("Timed out waiting for selector '{selector}'"))?; + } + SelectorKind::XPath(xpath) => { + client + .wait() + .for_element(Locator::XPath(&xpath)) + .await + .with_context(|| format!("Timed out waiting for selector '{selector}'"))?; + } + } + Ok(()) + } + + async fn find_element( + client: &Client, + selector: &str, + ) -> Result { + let element = match parse_selector(selector) { + SelectorKind::Css(css) => client + .find(Locator::Css(&css)) + .await + .with_context(|| format!("Failed to find element by CSS '{css}'"))?, + SelectorKind::XPath(xpath) => client + .find(Locator::XPath(&xpath)) + .await + .with_context(|| format!("Failed to find element by XPath '{xpath}'"))?, + }; + Ok(element) + } + + async fn hover_element(client: &Client, element: &fantoccini::elements::Element) -> Result<()> { + let actions = MouseActions::new("mouse".to_string()).then(PointerAction::MoveToElement { + element: element.clone(), + duration: Some(Duration::from_millis(150)), + x: 0.0, + y: 0.0, + }); + + client + .perform_actions(actions) + .await + .context("Failed to perform hover action")?; + let _ = client.release_actions().await; + Ok(()) + } + + async fn element_checked(element: &fantoccini::elements::Element) -> Result { + let checked = element + .prop("checked") + .await + .context("Failed to read checkbox checked property")? + .unwrap_or_default() + .to_ascii_lowercase(); + Ok(matches!(checked.as_str(), "true" | "checked" | "1")) + } + + enum SelectorKind { + Css(String), + XPath(String), + } + + fn parse_selector(selector: &str) -> SelectorKind { + let trimmed = selector.trim(); + if let Some(text_query) = trimmed.strip_prefix("text=") { + return SelectorKind::XPath(xpath_contains_text(text_query)); + } + + if let Some(label_query) = trimmed.strip_prefix("label=") { + let literal = xpath_literal(label_query); + return SelectorKind::XPath(format!( + "(//label[contains(normalize-space(.), {literal})]/following::*[self::input or self::textarea or self::select][1] | //*[@aria-label and contains(normalize-space(@aria-label), {literal})] | //label[contains(normalize-space(.), {literal})])" + )); + } + + if trimmed.starts_with('@') { + let escaped = css_attr_escape(trimmed); + return SelectorKind::Css(format!(r#"[data-zc-ref=\"{escaped}\"]"#)); + } + + SelectorKind::Css(trimmed.to_string()) + } + + fn css_attr_escape(input: &str) -> String { + input + .replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('\n', " ") + } + + fn xpath_contains_text(text: &str) -> String { + format!("//*[contains(normalize-space(.), {})]", xpath_literal(text)) + } + + fn xpath_literal(input: &str) -> String { + if !input.contains('"') { + return format!("\"{input}\""); + } + if !input.contains('\'') { + return format!("'{input}'"); + } + + let segments: Vec<&str> = input.split('"').collect(); + let mut parts: Vec = Vec::new(); + for (index, part) in segments.iter().enumerate() { + if !part.is_empty() { + parts.push(format!("\"{part}\"")); + } + if index + 1 < segments.len() { + parts.push("'\"'".to_string()); + } + } + + if parts.is_empty() { + "\"\"".to_string() + } else { + format!("concat({})", parts.join(",")) + } + } + + fn webdriver_key(key: &str) -> String { + match key.trim().to_ascii_lowercase().as_str() { + "enter" => Key::Enter.to_string(), + "return" => Key::Return.to_string(), + "tab" => Key::Tab.to_string(), + "escape" | "esc" => Key::Escape.to_string(), + "backspace" => Key::Backspace.to_string(), + "delete" => Key::Delete.to_string(), + "space" => Key::Space.to_string(), + "arrowup" | "up" => Key::Up.to_string(), + "arrowdown" | "down" => Key::Down.to_string(), + "arrowleft" | "left" => Key::Left.to_string(), + "arrowright" | "right" => Key::Right.to_string(), + "home" => Key::Home.to_string(), + "end" => Key::End.to_string(), + "pageup" => Key::PageUp.to_string(), + "pagedown" => Key::PageDown.to_string(), + other => other.to_string(), + } + } + + fn snapshot_script(interactive_only: bool, compact: bool, depth: Option) -> String { + let depth_literal = depth + .map(|level| level.to_string()) + .unwrap_or_else(|| "null".to_string()); + + format!( + r#"(() => {{ + const interactiveOnly = {interactive_only}; + const compact = {compact}; + const maxDepth = {depth_literal}; + const nodes = []; + const root = document.body || document.documentElement; + let counter = 0; + + const isVisible = (el) => {{ + const style = window.getComputedStyle(el); + if (style.display === 'none' || style.visibility === 'hidden' || Number(style.opacity || 1) === 0) {{ + return false; + }} + const rect = el.getBoundingClientRect(); + return rect.width > 0 && rect.height > 0; + }}; + + const isInteractive = (el) => {{ + if (el.matches('a,button,input,select,textarea,summary,[role],*[tabindex]')) return true; + return typeof el.onclick === 'function'; + }}; + + const describe = (el, depth) => {{ + const interactive = isInteractive(el); + const text = (el.innerText || el.textContent || '').trim().replace(/\s+/g, ' ').slice(0, 140); + if (interactiveOnly && !interactive) return; + if (compact && !interactive && !text) return; + + const ref = '@e' + (++counter); + el.setAttribute('data-zc-ref', ref); + nodes.push({{ + ref, + depth, + tag: el.tagName.toLowerCase(), + id: el.id || null, + role: el.getAttribute('role'), + text, + interactive, + }}); + }}; + + const walk = (el, depth) => {{ + if (!(el instanceof Element)) return; + if (maxDepth !== null && depth > maxDepth) return; + if (isVisible(el)) {{ + describe(el, depth); + }} + for (const child of el.children) {{ + walk(child, depth + 1); + if (nodes.length >= 400) return; + }} + }}; + + if (root) walk(root, 0); + + return {{ + title: document.title, + url: window.location.href, + count: nodes.length, + nodes, + }}; +}})();"# + ) + } +} + +// ── Action parsing ────────────────────────────────────────────── + +/// Parse a JSON `args` object into a typed `BrowserAction`. +fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result { + match action_str { + "open" => { + let url = args + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; + Ok(BrowserAction::Open { url: url.into() }) + } + "snapshot" => Ok(BrowserAction::Snapshot { + interactive_only: args + .get("interactive_only") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true), + compact: args + .get("compact") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true), + depth: args + .get("depth") + .and_then(serde_json::Value::as_u64) + .map(|d| u32::try_from(d).unwrap_or(u32::MAX)), + }), + "click" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?; + Ok(BrowserAction::Click { + selector: selector.into(), + }) + } + "fill" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?; + let value = args + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?; + Ok(BrowserAction::Fill { + selector: selector.into(), + value: value.into(), + }) + } + "type" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?; + let text = args + .get("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?; + Ok(BrowserAction::Type { + selector: selector.into(), + text: text.into(), + }) + } + "get_text" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?; + Ok(BrowserAction::GetText { + selector: selector.into(), + }) + } + "get_title" => Ok(BrowserAction::GetTitle), + "get_url" => Ok(BrowserAction::GetUrl), + "screenshot" => Ok(BrowserAction::Screenshot { + path: args.get("path").and_then(|v| v.as_str()).map(String::from), + full_page: args + .get("full_page") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false), + }), + "wait" => Ok(BrowserAction::Wait { + selector: args + .get("selector") + .and_then(|v| v.as_str()) + .map(String::from), + ms: args.get("ms").and_then(serde_json::Value::as_u64), + text: args.get("text").and_then(|v| v.as_str()).map(String::from), + }), + "press" => { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?; + Ok(BrowserAction::Press { key: key.into() }) + } + "hover" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?; + Ok(BrowserAction::Hover { + selector: selector.into(), + }) + } + "scroll" => { + let direction = args + .get("direction") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?; + Ok(BrowserAction::Scroll { + direction: direction.into(), + pixels: args + .get("pixels") + .and_then(serde_json::Value::as_u64) + .map(|p| u32::try_from(p).unwrap_or(u32::MAX)), + }) + } + "is_visible" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?; + Ok(BrowserAction::IsVisible { + selector: selector.into(), + }) + } + "close" => Ok(BrowserAction::Close), + "find" => { + let by = args + .get("by") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?; + let value = args + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?; + let action = args + .get("find_action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?; + Ok(BrowserAction::Find { + by: by.into(), + value: value.into(), + action: action.into(), + fill_value: args + .get("fill_value") + .and_then(|v| v.as_str()) + .map(String::from), + }) + } + other => anyhow::bail!("Unsupported browser action: {other}"), } } // ── Helper functions ───────────────────────────────────────────── +fn is_supported_browser_action(action: &str) -> bool { + matches!( + action, + "open" + | "snapshot" + | "click" + | "fill" + | "type" + | "get_text" + | "get_title" + | "get_url" + | "screenshot" + | "wait" + | "press" + | "hover" + | "scroll" + | "is_visible" + | "close" + | "find" + | "mouse_move" + | "mouse_click" + | "mouse_drag" + | "key_type" + | "key_press" + | "screen_capture" + ) +} + +fn is_computer_use_only_action(action: &str) -> bool { + matches!( + action, + "mouse_move" | "mouse_click" | "mouse_drag" | "key_type" | "key_press" | "screen_capture" + ) +} + +fn backend_name(backend: ResolvedBackend) -> &'static str { + match backend { + ResolvedBackend::AgentBrowser => "agent_browser", + ResolvedBackend::RustNative => "rust_native", + ResolvedBackend::ComputerUse => "computer_use", + } +} + +fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String { + format!( + "Action '{action}' is unavailable for backend '{}'", + backend_name(backend) + ) +} + fn normalize_domains(domains: Vec) -> Vec { domains .into_iter() @@ -668,6 +1941,30 @@ fn normalize_domains(domains: Vec) -> Vec { .collect() } +fn endpoint_reachable(endpoint: &reqwest::Url, timeout: Duration) -> bool { + let host = match endpoint.host_str() { + Some(host) if !host.is_empty() => host, + _ => return false, + }; + + let port = match endpoint.port_or_known_default() { + Some(port) => port, + None => return false, + }; + + let mut addrs = match (host, port).to_socket_addrs() { + Ok(addrs) => addrs, + Err(_) => return false, + }; + + let addr = match addrs.next() { + Some(addr) => addr, + None => return false, + }; + + std::net::TcpStream::connect_timeout(&addr, timeout).is_ok() +} + fn extract_host(url_str: &str) -> anyhow::Result { // Simple host extraction without url crate let url = url_str.trim(); @@ -677,14 +1974,16 @@ fn extract_host(url_str: &str) -> anyhow::Result { .or_else(|| url.strip_prefix("file://")) .unwrap_or(url); - // Extract host (before first / or :) - let host = without_scheme - .split('/') - .next() - .unwrap_or(without_scheme) - .split(':') - .next() - .unwrap_or(without_scheme); + // Extract host — handle bracketed IPv6 addresses like [::1]:8080 + let authority = without_scheme.split('/').next().unwrap_or(without_scheme); + + let host = if authority.starts_with('[') { + // IPv6: take everything up to and including the closing ']' + authority.find(']').map_or(authority, |i| &authority[..=i]) + } else { + // IPv4 or hostname: take everything before the port separator + authority.split(':').next().unwrap_or(authority) + }; if host.is_empty() { anyhow::bail!("Invalid URL: no host"); @@ -694,35 +1993,69 @@ fn extract_host(url_str: &str) -> anyhow::Result { } fn is_private_host(host: &str) -> bool { - let private_patterns = [ - "localhost", - "127.", - "10.", - "192.168.", - "172.16.", - "172.17.", - "172.18.", - "172.19.", - "172.20.", - "172.21.", - "172.22.", - "172.23.", - "172.24.", - "172.25.", - "172.26.", - "172.27.", - "172.28.", - "172.29.", - "172.30.", - "172.31.", - "0.0.0.0", - "::1", - "[::1]", - ]; + // Strip brackets from IPv6 addresses like [::1] + let bare = host + .strip_prefix('[') + .and_then(|h| h.strip_suffix(']')) + .unwrap_or(host); - private_patterns - .iter() - .any(|p| host.starts_with(p) || host == *p) + if bare == "localhost" || bare.ends_with(".localhost") { + return true; + } + + // .local TLD (mDNS) + if bare + .rsplit('.') + .next() + .is_some_and(|label| label == "local") + { + return true; + } + + // Parse as IP address to catch all representations (decimal, hex, octal, mapped) + if let Ok(ip) = bare.parse::() { + return match ip { + std::net::IpAddr::V4(v4) => is_non_global_v4(v4), + std::net::IpAddr::V6(v6) => is_non_global_v6(v6), + }; + } + + false +} + +/// Returns `true` for any IPv4 address that is not globally routable. +fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool { + let [a, b, _, _] = v4.octets(); + v4.is_loopback() + || v4.is_private() + || v4.is_link_local() + || v4.is_unspecified() + || v4.is_broadcast() + || v4.is_multicast() + // Shared address space (100.64/10) + || (a == 100 && (64..=127).contains(&b)) + // Reserved (240.0.0.0/4) + || a >= 240 + // Documentation (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24) + || (a == 192 && b == 0) + || (a == 198 && b == 51) + || (a == 203 && b == 0) + // Benchmarking (198.18.0.0/15) + || (a == 198 && (18..=19).contains(&b)) +} + +/// Returns `true` for any IPv6 address that is not globally routable. +fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool { + let segs = v6.segments(); + v6.is_loopback() + || v6.is_unspecified() + || v6.is_multicast() + // Unique-local (fc00::/7) — IPv6 equivalent of RFC 1918 + || (segs[0] & 0xfe00) == 0xfc00 + // Link-local (fe80::/10) + || (segs[0] & 0xffc0) == 0xfe80 + // IPv4-mapped addresses + || v6.to_ipv4_mapped().is_some_and(is_non_global_v4) } fn host_matches_allowlist(host: &str, allowed: &[String]) -> bool { @@ -768,9 +2101,24 @@ mod tests { ); } + #[test] + fn extract_host_handles_ipv6() { + // IPv6 with brackets (required for URLs with ports) + assert_eq!(extract_host("https://[::1]/path").unwrap(), "[::1]"); + // IPv6 with brackets and port + assert_eq!( + extract_host("https://[2001:db8::1]:8080/path").unwrap(), + "[2001:db8::1]" + ); + // IPv6 with brackets, trailing slash + assert_eq!(extract_host("https://[fe80::1]/").unwrap(), "[fe80::1]"); + } + #[test] fn is_private_host_detects_local() { assert!(is_private_host("localhost")); + assert!(is_private_host("app.localhost")); + assert!(is_private_host("printer.local")); assert!(is_private_host("127.0.0.1")); assert!(is_private_host("192.168.1.1")); assert!(is_private_host("10.0.0.1")); @@ -778,6 +2126,55 @@ mod tests { assert!(!is_private_host("google.com")); } + #[test] + fn is_private_host_blocks_multicast_and_reserved() { + assert!(is_private_host("224.0.0.1")); // multicast + assert!(is_private_host("255.255.255.255")); // broadcast + assert!(is_private_host("100.64.0.1")); // shared address space + assert!(is_private_host("240.0.0.1")); // reserved + assert!(is_private_host("192.0.2.1")); // documentation + assert!(is_private_host("198.51.100.1")); // documentation + assert!(is_private_host("203.0.113.1")); // documentation + assert!(is_private_host("198.18.0.1")); // benchmarking + } + + #[test] + fn is_private_host_catches_ipv6() { + assert!(is_private_host("::1")); + assert!(is_private_host("[::1]")); + assert!(is_private_host("0.0.0.0")); + } + + #[test] + fn is_private_host_catches_mapped_ipv4() { + // IPv4-mapped IPv6 addresses + assert!(is_private_host("::ffff:127.0.0.1")); + assert!(is_private_host("::ffff:10.0.0.1")); + assert!(is_private_host("::ffff:192.168.1.1")); + } + + #[test] + fn is_private_host_catches_ipv6_private_ranges() { + // Unique-local (fc00::/7) + assert!(is_private_host("fd00::1")); + assert!(is_private_host("fc00::1")); + // Link-local (fe80::/10) + assert!(is_private_host("fe80::1")); + // Public IPv6 should pass + assert!(!is_private_host("2001:db8::1")); + } + + #[test] + fn validate_url_blocks_ipv6_ssrf() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new(security, vec!["*".into()], None); + assert!(tool.validate_url("https://[::1]/").is_err()); + assert!(tool.validate_url("https://[::ffff:127.0.0.1]/").is_err()); + assert!(tool + .validate_url("https://[::ffff:10.0.0.1]:8080/") + .is_err()); + } + #[test] fn host_matches_allowlist_exact() { let allowed = vec!["example.com".into()]; @@ -801,6 +2198,146 @@ mod tests { assert!(host_matches_allowlist("example.org", &allowed)); } + #[test] + fn browser_backend_parser_accepts_supported_values() { + assert_eq!( + BrowserBackendKind::parse("agent_browser").unwrap(), + BrowserBackendKind::AgentBrowser + ); + assert_eq!( + BrowserBackendKind::parse("rust-native").unwrap(), + BrowserBackendKind::RustNative + ); + assert_eq!( + BrowserBackendKind::parse("computer_use").unwrap(), + BrowserBackendKind::ComputerUse + ); + assert_eq!( + BrowserBackendKind::parse("auto").unwrap(), + BrowserBackendKind::Auto + ); + } + + #[test] + fn browser_backend_parser_rejects_unknown_values() { + assert!(BrowserBackendKind::parse("playwright").is_err()); + } + + #[test] + fn browser_tool_default_backend_is_agent_browser() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new(security, vec!["example.com".into()], None); + assert_eq!( + tool.configured_backend().unwrap(), + BrowserBackendKind::AgentBrowser + ); + } + + #[test] + fn browser_tool_accepts_auto_backend_config() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new_with_backend( + security, + vec!["example.com".into()], + None, + "auto".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig::default(), + ); + assert_eq!(tool.configured_backend().unwrap(), BrowserBackendKind::Auto); + } + + #[test] + fn browser_tool_accepts_computer_use_backend_config() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new_with_backend( + security, + vec!["example.com".into()], + None, + "computer_use".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig::default(), + ); + assert_eq!( + tool.configured_backend().unwrap(), + BrowserBackendKind::ComputerUse + ); + } + + #[test] + fn computer_use_endpoint_rejects_public_http_by_default() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new_with_backend( + security, + vec!["example.com".into()], + None, + "computer_use".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig { + endpoint: "http://computer-use.example.com/v1/actions".into(), + ..ComputerUseConfig::default() + }, + ); + + assert!(tool.computer_use_endpoint_url().is_err()); + } + + #[test] + fn computer_use_endpoint_requires_https_for_public_remote() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new_with_backend( + security, + vec!["example.com".into()], + None, + "computer_use".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig { + endpoint: "https://computer-use.example.com/v1/actions".into(), + allow_remote_endpoint: true, + ..ComputerUseConfig::default() + }, + ); + + assert!(tool.computer_use_endpoint_url().is_ok()); + } + + #[test] + fn computer_use_coordinate_validation_applies_limits() { + let security = Arc::new(SecurityPolicy::default()); + let tool = BrowserTool::new_with_backend( + security, + vec!["example.com".into()], + None, + "computer_use".into(), + true, + "http://127.0.0.1:9515".into(), + None, + ComputerUseConfig { + max_coordinate_x: Some(100), + max_coordinate_y: Some(100), + ..ComputerUseConfig::default() + }, + ); + + assert!(tool + .validate_coordinate("x", 50, tool.computer_use.max_coordinate_x) + .is_ok()); + assert!(tool + .validate_coordinate("x", 101, tool.computer_use.max_coordinate_x) + .is_err()); + assert!(tool + .validate_coordinate("y", -1, tool.computer_use.max_coordinate_y) + .is_err()); + } + #[test] fn browser_tool_name() { let security = Arc::new(SecurityPolicy::default()); @@ -827,8 +2364,8 @@ mod tests { // Invalid - not https assert!(tool.validate_url("ftp://example.com").is_err()); - // File URLs allowed - assert!(tool.validate_url("file:///tmp/test.html").is_ok()); + // file:// URLs blocked (local file exfiltration risk) + assert!(tool.validate_url("file:///tmp/test.html").is_err()); } #[test] @@ -837,4 +2374,28 @@ mod tests { let tool = BrowserTool::new(security, vec![], None); assert!(tool.validate_url("https://example.com").is_err()); } + + #[test] + fn computer_use_only_action_detection_is_correct() { + assert!(is_computer_use_only_action("mouse_move")); + assert!(is_computer_use_only_action("mouse_click")); + assert!(is_computer_use_only_action("mouse_drag")); + assert!(is_computer_use_only_action("key_type")); + assert!(is_computer_use_only_action("key_press")); + assert!(is_computer_use_only_action("screen_capture")); + assert!(!is_computer_use_only_action("open")); + assert!(!is_computer_use_only_action("snapshot")); + } + + #[test] + fn unavailable_action_error_preserves_backend_context() { + assert_eq!( + unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser), + "Action 'mouse_move' is unavailable for backend 'agent_browser'" + ); + assert_eq!( + unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative), + "Action 'mouse_move' is unavailable for backend 'rust_native'" + ); + } } diff --git a/src/tools/composio.rs b/src/tools/composio.rs index 4602d5d..65f128e 100644 --- a/src/tools/composio.rs +++ b/src/tools/composio.rs @@ -7,23 +7,27 @@ // The Composio API key is stored in the encrypted secret store. use super::traits::{Tool, ToolResult}; +use anyhow::Context; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; -const COMPOSIO_API_BASE: &str = "https://backend.composio.dev/api/v2"; +const COMPOSIO_API_BASE_V2: &str = "https://backend.composio.dev/api/v2"; +const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3"; /// A tool that proxies actions to the Composio managed tool platform. pub struct ComposioTool { api_key: String, + default_entity_id: String, client: Client, } impl ComposioTool { - pub fn new(api_key: &str) -> Self { + pub fn new(api_key: &str, default_entity_id: Option<&str>) -> Self { Self { api_key: api_key.to_string(), + default_entity_id: normalize_entity_id(default_entity_id.unwrap_or("default")), client: Client::builder() .timeout(std::time::Duration::from_secs(60)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -33,11 +37,50 @@ impl ComposioTool { } /// List available Composio apps/actions for the authenticated user. + /// + /// Uses v3 endpoint first and falls back to v2 for compatibility. pub async fn list_actions( &self, app_name: Option<&str>, ) -> anyhow::Result> { - let mut url = format!("{COMPOSIO_API_BASE}/actions"); + match self.list_actions_v3(app_name).await { + Ok(items) => Ok(items), + Err(v3_err) => { + let v2 = self.list_actions_v2(app_name).await; + match v2 { + Ok(items) => Ok(items), + Err(v2_err) => anyhow::bail!( + "Composio action listing failed on v3 ({v3_err}) and v2 fallback ({v2_err})" + ), + } + } + } + } + + async fn list_actions_v3(&self, app_name: Option<&str>) -> anyhow::Result> { + let url = format!("{COMPOSIO_API_BASE_V3}/tools"); + let mut req = self.client.get(&url).header("x-api-key", &self.api_key); + + req = req.query(&[("limit", "200")]); + if let Some(app) = app_name.map(str::trim).filter(|app| !app.is_empty()) { + req = req.query(&[("toolkits", app), ("toolkit_slug", app)]); + } + + let resp = req.send().await?; + if !resp.status().is_success() { + let err = response_error(resp).await; + anyhow::bail!("Composio v3 API error: {err}"); + } + + let body: ComposioToolsResponse = resp + .json() + .await + .context("Failed to decode Composio v3 tools response")?; + Ok(map_v3_tools_to_actions(body.items)) + } + + async fn list_actions_v2(&self, app_name: Option<&str>) -> anyhow::Result> { + let mut url = format!("{COMPOSIO_API_BASE_V2}/actions"); if let Some(app) = app_name { url = format!("{url}?appNames={app}"); } @@ -50,22 +93,110 @@ impl ComposioTool { .await?; if !resp.status().is_success() { - let err = resp.text().await.unwrap_or_default(); - anyhow::bail!("Composio API error: {err}"); + let err = response_error(resp).await; + anyhow::bail!("Composio v2 API error: {err}"); } - let body: ComposioActionsResponse = resp.json().await?; + let body: ComposioActionsResponse = resp + .json() + .await + .context("Failed to decode Composio v2 actions response")?; Ok(body.items) } - /// Execute a Composio action by name with given parameters. + /// Execute a Composio action/tool with given parameters. + /// + /// Uses v3 endpoint first and falls back to v2 for compatibility. pub async fn execute_action( &self, action_name: &str, params: serde_json::Value, entity_id: Option<&str>, + connected_account_ref: Option<&str>, ) -> anyhow::Result { - let url = format!("{COMPOSIO_API_BASE}/actions/{action_name}/execute"); + let tool_slug = normalize_tool_slug(action_name); + + match self + .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref) + .await + { + Ok(result) => Ok(result), + Err(v3_err) => match self.execute_action_v2(action_name, params, entity_id).await { + Ok(result) => Ok(result), + Err(v2_err) => anyhow::bail!( + "Composio execute failed on v3 ({v3_err}) and v2 fallback ({v2_err})" + ), + }, + } + } + + fn build_execute_action_v3_request( + tool_slug: &str, + params: serde_json::Value, + entity_id: Option<&str>, + connected_account_ref: Option<&str>, + ) -> (String, serde_json::Value) { + let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute"); + let account_ref = connected_account_ref.and_then(|candidate| { + let trimmed_candidate = candidate.trim(); + (!trimmed_candidate.is_empty()).then_some(trimmed_candidate) + }); + + let mut body = json!({ + "arguments": params, + }); + + if let Some(entity) = entity_id { + body["user_id"] = json!(entity); + } + if let Some(account_ref) = account_ref { + body["connected_account_id"] = json!(account_ref); + } + + (url, body) + } + + async fn execute_action_v3( + &self, + tool_slug: &str, + params: serde_json::Value, + entity_id: Option<&str>, + connected_account_ref: Option<&str>, + ) -> anyhow::Result { + let (url, body) = Self::build_execute_action_v3_request( + tool_slug, + params, + entity_id, + connected_account_ref, + ); + + let resp = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let err = response_error(resp).await; + anyhow::bail!("Composio v3 action execution failed: {err}"); + } + + let result: serde_json::Value = resp + .json() + .await + .context("Failed to decode Composio v3 execute response")?; + Ok(result) + } + + async fn execute_action_v2( + &self, + action_name: &str, + params: serde_json::Value, + entity_id: Option<&str>, + ) -> anyhow::Result { + let url = format!("{COMPOSIO_API_BASE_V2}/actions/{action_name}/execute"); let mut body = json!({ "input": params, @@ -84,21 +215,96 @@ impl ComposioTool { .await?; if !resp.status().is_success() { - let err = resp.text().await.unwrap_or_default(); - anyhow::bail!("Composio action execution failed: {err}"); + let err = response_error(resp).await; + anyhow::bail!("Composio v2 action execution failed: {err}"); } - let result: serde_json::Value = resp.json().await?; + let result: serde_json::Value = resp + .json() + .await + .context("Failed to decode Composio v2 execute response")?; Ok(result) } - /// Get the OAuth connection URL for a specific app. + /// Get the OAuth connection URL for a specific app/toolkit or auth config. + /// + /// Uses v3 endpoint first and falls back to v2 for compatibility. pub async fn get_connection_url( + &self, + app_name: Option<&str>, + auth_config_id: Option<&str>, + entity_id: &str, + ) -> anyhow::Result { + let v3 = self + .get_connection_url_v3(app_name, auth_config_id, entity_id) + .await; + match v3 { + Ok(url) => Ok(url), + Err(v3_err) => { + let app = app_name.ok_or_else(|| { + anyhow::anyhow!( + "Composio v3 connect failed ({v3_err}) and v2 fallback requires 'app'" + ) + })?; + match self.get_connection_url_v2(app, entity_id).await { + Ok(url) => Ok(url), + Err(v2_err) => anyhow::bail!( + "Composio connect failed on v3 ({v3_err}) and v2 fallback ({v2_err})" + ), + } + } + } + } + + async fn get_connection_url_v3( + &self, + app_name: Option<&str>, + auth_config_id: Option<&str>, + entity_id: &str, + ) -> anyhow::Result { + let auth_config_id = match auth_config_id { + Some(id) => id.to_string(), + None => { + let app = app_name.ok_or_else(|| { + anyhow::anyhow!("Missing 'app' or 'auth_config_id' for v3 connect") + })?; + self.resolve_auth_config_id(app).await? + } + }; + + let url = format!("{COMPOSIO_API_BASE_V3}/connected_accounts/link"); + let body = json!({ + "auth_config_id": auth_config_id, + "user_id": entity_id, + }); + + let resp = self + .client + .post(&url) + .header("x-api-key", &self.api_key) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let err = response_error(resp).await; + anyhow::bail!("Composio v3 connect failed: {err}"); + } + + let result: serde_json::Value = resp + .json() + .await + .context("Failed to decode Composio v3 connect response")?; + extract_redirect_url(&result) + .ok_or_else(|| anyhow::anyhow!("No redirect URL in Composio v3 response")) + } + + async fn get_connection_url_v2( &self, app_name: &str, entity_id: &str, ) -> anyhow::Result { - let url = format!("{COMPOSIO_API_BASE}/connectedAccounts"); + let url = format!("{COMPOSIO_API_BASE_V2}/connectedAccounts"); let body = json!({ "integrationId": app_name, @@ -114,16 +320,57 @@ impl ComposioTool { .await?; if !resp.status().is_success() { - let err = resp.text().await.unwrap_or_default(); - anyhow::bail!("Failed to get connection URL: {err}"); + let err = response_error(resp).await; + anyhow::bail!("Composio v2 connect failed: {err}"); } - let result: serde_json::Value = resp.json().await?; - result - .get("redirectUrl") - .and_then(|v| v.as_str()) - .map(String::from) - .ok_or_else(|| anyhow::anyhow!("No redirect URL in response")) + let result: serde_json::Value = resp + .json() + .await + .context("Failed to decode Composio v2 connect response")?; + extract_redirect_url(&result) + .ok_or_else(|| anyhow::anyhow!("No redirect URL in Composio v2 response")) + } + + async fn resolve_auth_config_id(&self, app_name: &str) -> anyhow::Result { + let url = format!("{COMPOSIO_API_BASE_V3}/auth_configs"); + + let resp = self + .client + .get(&url) + .header("x-api-key", &self.api_key) + .query(&[ + ("toolkit_slug", app_name), + ("show_disabled", "true"), + ("limit", "25"), + ]) + .send() + .await?; + + if !resp.status().is_success() { + let err = response_error(resp).await; + anyhow::bail!("Composio v3 auth config lookup failed: {err}"); + } + + let body: ComposioAuthConfigsResponse = resp + .json() + .await + .context("Failed to decode Composio v3 auth configs response")?; + + if body.items.is_empty() { + anyhow::bail!( + "No auth config found for toolkit '{app_name}'. Create one in Composio first." + ); + } + + let preferred = body + .items + .iter() + .find(|cfg| cfg.is_enabled()) + .or_else(|| body.items.first()) + .context("No usable auth config returned by Composio")?; + + Ok(preferred.id.clone()) } } @@ -135,7 +382,8 @@ impl Tool for ComposioTool { fn description(&self) -> &str { "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). \ - Use action='list' to see available actions, or action='execute' with action_name and params." + Use action='list' to see available actions, action='execute' with action_name/tool_slug, params, and optional connected_account_id, \ + or action='connect' with app/auth_config_id to get OAuth URL." } fn parameters_schema(&self) -> serde_json::Value { @@ -149,11 +397,15 @@ impl Tool for ComposioTool { }, "app": { "type": "string", - "description": "App name filter for 'list', or app name for 'connect' (e.g. 'gmail', 'notion', 'github')" + "description": "Toolkit slug filter for 'list', or toolkit/app for 'connect' (e.g. 'gmail', 'notion', 'github')" }, "action_name": { "type": "string", - "description": "The Composio action name to execute (e.g. 'GMAIL_FETCH_EMAILS')" + "description": "Action/tool identifier to execute (legacy aliases supported)" + }, + "tool_slug": { + "type": "string", + "description": "Preferred v3 tool slug to execute (alias of action_name)" }, "params": { "type": "object", @@ -161,7 +413,15 @@ impl Tool for ComposioTool { }, "entity_id": { "type": "string", - "description": "Entity ID for multi-user setups (defaults to 'default')" + "description": "Entity/user ID for multi-user setups (defaults to composio.entity_id from config)" + }, + "auth_config_id": { + "type": "string", + "description": "Optional Composio v3 auth config id for connect flow" + }, + "connected_account_id": { + "type": "string", + "description": "Optional connected account ID for execute flow when a specific account is required" } }, "required": ["action"] @@ -177,7 +437,7 @@ impl Tool for ComposioTool { let entity_id = args .get("entity_id") .and_then(|v| v.as_str()) - .unwrap_or("default"); + .unwrap_or(self.default_entity_id.as_str()); match action { "list" => { @@ -222,14 +482,19 @@ impl Tool for ComposioTool { "execute" => { let action_name = args - .get("action_name") + .get("tool_slug") + .or_else(|| args.get("action_name")) .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'action_name' for execute"))?; + .ok_or_else(|| { + anyhow::anyhow!("Missing 'action_name' (or 'tool_slug') for execute") + })?; let params = args.get("params").cloned().unwrap_or(json!({})); + let connected_account_ref = + args.get("connected_account_id").and_then(|v| v.as_str()); match self - .execute_action(action_name, params, Some(entity_id)) + .execute_action(action_name, params, Some(entity_id), connected_account_ref) .await { Ok(result) => { @@ -250,17 +515,26 @@ impl Tool for ComposioTool { } "connect" => { - let app = args - .get("app") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'app' for connect"))?; + let app = args.get("app").and_then(|v| v.as_str()); + let auth_config_id = args.get("auth_config_id").and_then(|v| v.as_str()); - match self.get_connection_url(app, entity_id).await { - Ok(url) => Ok(ToolResult { - success: true, - output: format!("Open this URL to connect {app}:\n{url}"), - error: None, - }), + if app.is_none() && auth_config_id.is_none() { + anyhow::bail!("Missing 'app' or 'auth_config_id' for connect"); + } + + match self + .get_connection_url(app, auth_config_id, entity_id) + .await + { + Ok(url) => { + let target = + app.unwrap_or(auth_config_id.unwrap_or("provided auth config")); + Ok(ToolResult { + success: true, + output: format!("Open this URL to connect {target}:\n{url}"), + error: None, + }) + } Err(e) => Ok(ToolResult { success: false, output: String::new(), @@ -280,6 +554,112 @@ impl Tool for ComposioTool { } } +fn normalize_entity_id(entity_id: &str) -> String { + let trimmed = entity_id.trim(); + if trimmed.is_empty() { + "default".to_string() + } else { + trimmed.to_string() + } +} + +fn normalize_tool_slug(action_name: &str) -> String { + action_name.trim().replace('_', "-").to_ascii_lowercase() +} + +fn map_v3_tools_to_actions(items: Vec) -> Vec { + items + .into_iter() + .filter_map(|item| { + let name = item.slug.or(item.name.clone())?; + let app_name = item + .toolkit + .as_ref() + .and_then(|toolkit| toolkit.slug.clone().or(toolkit.name.clone())) + .or(item.app_name); + let description = item.description.or(item.name); + Some(ComposioAction { + name, + app_name, + description, + enabled: true, + }) + }) + .collect() +} + +fn extract_redirect_url(result: &serde_json::Value) -> Option { + result + .get("redirect_url") + .and_then(|v| v.as_str()) + .or_else(|| result.get("redirectUrl").and_then(|v| v.as_str())) + .or_else(|| { + result + .get("data") + .and_then(|v| v.get("redirect_url")) + .and_then(|v| v.as_str()) + }) + .map(ToString::to_string) +} + +async fn response_error(resp: reqwest::Response) -> String { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + if body.trim().is_empty() { + return format!("HTTP {}", status.as_u16()); + } + + if let Some(api_error) = extract_api_error_message(&body) { + return format!( + "HTTP {}: {}", + status.as_u16(), + sanitize_error_message(&api_error) + ); + } + + format!("HTTP {}", status.as_u16()) +} + +fn sanitize_error_message(message: &str) -> String { + let mut sanitized = message.replace('\n', " "); + for marker in [ + "connected_account_id", + "connectedAccountId", + "entity_id", + "entityId", + "user_id", + "userId", + ] { + sanitized = sanitized.replace(marker, "[redacted]"); + } + + let max_chars = 240; + if sanitized.chars().count() <= max_chars { + sanitized + } else { + let mut end = max_chars; + while end > 0 && !sanitized.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &sanitized[..end]) + } +} + +fn extract_api_error_message(body: &str) -> Option { + let parsed: serde_json::Value = serde_json::from_str(body).ok()?; + parsed + .get("error") + .and_then(|v| v.get("message")) + .and_then(|v| v.as_str()) + .map(ToString::to_string) + .or_else(|| { + parsed + .get("message") + .and_then(|v| v.as_str()) + .map(ToString::to_string) + }) +} + // ── API response types ────────────────────────────────────────── #[derive(Debug, Deserialize)] @@ -288,6 +668,59 @@ struct ComposioActionsResponse { items: Vec, } +#[derive(Debug, Deserialize)] +struct ComposioToolsResponse { + #[serde(default)] + items: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct ComposioV3Tool { + #[serde(default)] + slug: Option, + #[serde(default)] + name: Option, + #[serde(default)] + description: Option, + #[serde(rename = "appName", default)] + app_name: Option, + #[serde(default)] + toolkit: Option, +} + +#[derive(Debug, Clone, Deserialize)] +struct ComposioToolkitRef { + #[serde(default)] + slug: Option, + #[serde(default)] + name: Option, +} + +#[derive(Debug, Deserialize)] +struct ComposioAuthConfigsResponse { + #[serde(default)] + items: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct ComposioAuthConfig { + id: String, + #[serde(default)] + status: Option, + #[serde(default)] + enabled: Option, +} + +impl ComposioAuthConfig { + fn is_enabled(&self) -> bool { + self.enabled.unwrap_or(false) + || self + .status + .as_deref() + .is_some_and(|v| v.eq_ignore_ascii_case("enabled")) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ComposioAction { pub name: String, @@ -306,32 +739,35 @@ mod tests { #[test] fn composio_tool_has_correct_name() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); assert_eq!(tool.name(), "composio"); } #[test] fn composio_tool_has_description() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); assert!(!tool.description().is_empty()); assert!(tool.description().contains("1000+")); } #[test] fn composio_tool_schema_has_required_fields() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let schema = tool.parameters_schema(); assert!(schema["properties"]["action"].is_object()); assert!(schema["properties"]["action_name"].is_object()); + assert!(schema["properties"]["tool_slug"].is_object()); assert!(schema["properties"]["params"].is_object()); assert!(schema["properties"]["app"].is_object()); + assert!(schema["properties"]["auth_config_id"].is_object()); + assert!(schema["properties"]["connected_account_id"].is_object()); let required = schema["required"].as_array().unwrap(); assert!(required.contains(&json!("action"))); } #[test] fn composio_tool_spec_roundtrip() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let spec = tool.spec(); assert_eq!(spec.name, "composio"); assert!(spec.parameters.is_object()); @@ -341,14 +777,14 @@ mod tests { #[tokio::test] async fn execute_missing_action_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[tokio::test] async fn execute_unknown_action_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "unknown"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("Unknown action")); @@ -356,14 +792,14 @@ mod tests { #[tokio::test] async fn execute_without_action_name_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "execute"})).await; assert!(result.is_err()); } #[tokio::test] - async fn connect_without_app_returns_error() { - let tool = ComposioTool::new("test-key"); + async fn connect_without_target_returns_error() { + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "connect"})).await; assert!(result.is_err()); } @@ -400,4 +836,197 @@ mod tests { let resp: ComposioActionsResponse = serde_json::from_str(json_str).unwrap(); assert!(resp.items.is_empty()); } + + #[test] + fn composio_v3_tools_response_maps_to_actions() { + let json_str = r#"{ + "items": [ + { + "slug": "gmail-fetch-emails", + "name": "Gmail Fetch Emails", + "description": "Fetch inbox emails", + "toolkit": { "slug": "gmail", "name": "Gmail" } + } + ] + }"#; + let resp: ComposioToolsResponse = serde_json::from_str(json_str).unwrap(); + let actions = map_v3_tools_to_actions(resp.items); + assert_eq!(actions.len(), 1); + assert_eq!(actions[0].name, "gmail-fetch-emails"); + assert_eq!(actions[0].app_name.as_deref(), Some("gmail")); + assert_eq!( + actions[0].description.as_deref(), + Some("Fetch inbox emails") + ); + } + + #[test] + fn normalize_entity_id_falls_back_to_default_when_blank() { + assert_eq!(normalize_entity_id(" "), "default"); + assert_eq!(normalize_entity_id("workspace-user"), "workspace-user"); + } + + #[test] + fn normalize_tool_slug_supports_legacy_action_name() { + assert_eq!( + normalize_tool_slug("GMAIL_FETCH_EMAILS"), + "gmail-fetch-emails" + ); + assert_eq!( + normalize_tool_slug(" github-list-repos "), + "github-list-repos" + ); + } + + #[test] + fn extract_redirect_url_supports_v2_and_v3_shapes() { + let v2 = json!({"redirectUrl": "https://app.composio.dev/connect-v2"}); + let v3 = json!({"redirect_url": "https://app.composio.dev/connect-v3"}); + let nested = json!({"data": {"redirect_url": "https://app.composio.dev/connect-nested"}}); + + assert_eq!( + extract_redirect_url(&v2).as_deref(), + Some("https://app.composio.dev/connect-v2") + ); + assert_eq!( + extract_redirect_url(&v3).as_deref(), + Some("https://app.composio.dev/connect-v3") + ); + assert_eq!( + extract_redirect_url(&nested).as_deref(), + Some("https://app.composio.dev/connect-nested") + ); + } + + #[test] + fn auth_config_prefers_enabled_status() { + let enabled = ComposioAuthConfig { + id: "cfg_1".into(), + status: Some("ENABLED".into()), + enabled: None, + }; + let disabled = ComposioAuthConfig { + id: "cfg_2".into(), + status: Some("DISABLED".into()), + enabled: Some(false), + }; + + assert!(enabled.is_enabled()); + assert!(!disabled.is_enabled()); + } + + #[test] + fn extract_api_error_message_from_common_shapes() { + let nested = r#"{"error":{"message":"tool not found"}}"#; + let flat = r#"{"message":"invalid api key"}"#; + + assert_eq!( + extract_api_error_message(nested).as_deref(), + Some("tool not found") + ); + assert_eq!( + extract_api_error_message(flat).as_deref(), + Some("invalid api key") + ); + assert_eq!(extract_api_error_message("not-json"), None); + } + + #[test] + fn composio_action_with_null_fields() { + let json_str = + r#"{"name": "TEST_ACTION", "appName": null, "description": null, "enabled": false}"#; + let action: ComposioAction = serde_json::from_str(json_str).unwrap(); + assert_eq!(action.name, "TEST_ACTION"); + assert!(action.app_name.is_none()); + assert!(action.description.is_none()); + assert!(!action.enabled); + } + + #[test] + fn composio_action_with_special_characters() { + let json_str = r#"{"name": "GMAIL_SEND_EMAIL_WITH_ATTACHMENT", "appName": "gmail", "description": "Send email with attachment & special chars: <>'\"\"", "enabled": true}"#; + let action: ComposioAction = serde_json::from_str(json_str).unwrap(); + assert_eq!(action.name, "GMAIL_SEND_EMAIL_WITH_ATTACHMENT"); + assert!(action.description.as_ref().unwrap().contains('&')); + assert!(action.description.as_ref().unwrap().contains('<')); + } + + #[test] + fn composio_action_with_unicode() { + let json_str = r#"{"name": "SLACK_SEND_MESSAGE", "appName": "slack", "description": "Send message with emoji 🎉 and unicode 中文", "enabled": true}"#; + let action: ComposioAction = serde_json::from_str(json_str).unwrap(); + assert!(action.description.as_ref().unwrap().contains("🎉")); + assert!(action.description.as_ref().unwrap().contains("中文")); + } + + #[test] + fn composio_malformed_json_returns_error() { + let json_str = r#"{"name": "TEST_ACTION", "appName": "gmail", }"#; + let result: Result = serde_json::from_str(json_str); + assert!(result.is_err()); + } + + #[test] + fn composio_empty_json_string_returns_error() { + let json_str = r#" ""#; + let result: Result = serde_json::from_str(json_str); + assert!(result.is_err()); + } + + #[test] + fn composio_large_actions_list() { + let mut items = Vec::new(); + for i in 0..100 { + items.push(json!({ + "name": format!("ACTION_{i}"), + "appName": "test", + "description": "Test action", + "enabled": true + })); + } + let json_str = json!({"items": items}).to_string(); + let resp: ComposioActionsResponse = serde_json::from_str(&json_str).unwrap(); + assert_eq!(resp.items.len(), 100); + } + + #[test] + fn composio_api_base_url_is_v3() { + assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3"); + } + + #[test] + fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() { + let (url, body) = ComposioTool::build_execute_action_v3_request( + "gmail-send-email", + json!({"to": "test@example.com"}), + Some("workspace-user"), + Some("account-42"), + ); + + assert_eq!( + url, + "https://backend.composio.dev/api/v3/tools/gmail-send-email/execute" + ); + assert_eq!(body["arguments"]["to"], json!("test@example.com")); + assert_eq!(body["user_id"], json!("workspace-user")); + assert_eq!(body["connected_account_id"], json!("account-42")); + } + + #[test] + fn build_execute_action_v3_request_drops_blank_optional_fields() { + let (url, body) = ComposioTool::build_execute_action_v3_request( + "github-list-repos", + json!({}), + None, + Some(" "), + ); + + assert_eq!( + url, + "https://backend.composio.dev/api/v3/tools/github-list-repos/execute" + ); + assert_eq!(body["arguments"], json!({})); + assert!(body.get("connected_account_id").is_none()); + assert!(body.get("user_id").is_none()); + } } diff --git a/src/tools/cron_add.rs b/src/tools/cron_add.rs new file mode 100644 index 0000000..bd3abea --- /dev/null +++ b/src/tools/cron_add.rs @@ -0,0 +1,326 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron::{self, DeliveryConfig, JobType, Schedule, SessionTarget}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +pub struct CronAddTool { + config: Arc, + security: Arc, +} + +impl CronAddTool { + pub fn new(config: Arc, security: Arc) -> Self { + Self { config, security } + } +} + +#[async_trait] +impl Tool for CronAddTool { + fn name(&self) -> &str { + "cron_add" + } + + fn description(&self) -> &str { + "Create a scheduled cron job (shell or agent) with cron/at/every schedules" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "schedule": { + "type": "object", + "description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}" + }, + "job_type": { "type": "string", "enum": ["shell", "agent"] }, + "command": { "type": "string" }, + "prompt": { "type": "string" }, + "session_target": { "type": "string", "enum": ["isolated", "main"] }, + "model": { "type": "string" }, + "delivery": { "type": "object" }, + "delete_after_run": { "type": "boolean" } + }, + "required": ["schedule"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + let schedule = match args.get("schedule") { + Some(v) => match serde_json::from_value::(v.clone()) { + Ok(schedule) => schedule, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid schedule: {e}")), + }); + } + }, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'schedule' parameter".to_string()), + }); + } + }; + + let name = args + .get("name") + .and_then(serde_json::Value::as_str) + .map(str::to_string); + + let job_type = match args.get("job_type").and_then(serde_json::Value::as_str) { + Some("agent") => JobType::Agent, + Some("shell") => JobType::Shell, + Some(other) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid job_type: {other}")), + }); + } + None => { + if args.get("prompt").is_some() { + JobType::Agent + } else { + JobType::Shell + } + } + }; + + let default_delete_after_run = matches!(schedule, Schedule::At { .. }); + let delete_after_run = args + .get("delete_after_run") + .and_then(serde_json::Value::as_bool) + .unwrap_or(default_delete_after_run); + + let result = match job_type { + JobType::Shell => { + let command = match args.get("command").and_then(serde_json::Value::as_str) { + Some(command) if !command.trim().is_empty() => command, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'command' for shell job".to_string()), + }); + } + }; + + if !self.security.is_command_allowed(command) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Command blocked by security policy: {command}")), + }); + } + + cron::add_shell_job(&self.config, name, schedule, command) + } + JobType::Agent => { + let prompt = match args.get("prompt").and_then(serde_json::Value::as_str) { + Some(prompt) if !prompt.trim().is_empty() => prompt, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'prompt' for agent job".to_string()), + }); + } + }; + + let session_target = match args.get("session_target") { + Some(v) => match serde_json::from_value::(v.clone()) { + Ok(target) => target, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid session_target: {e}")), + }); + } + }, + None => SessionTarget::Isolated, + }; + + let model = args + .get("model") + .and_then(serde_json::Value::as_str) + .map(str::to_string); + + let delivery = match args.get("delivery") { + Some(v) => match serde_json::from_value::(v.clone()) { + Ok(cfg) => Some(cfg), + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid delivery config: {e}")), + }); + } + }, + None => None, + }; + + cron::add_agent_job( + &self.config, + name, + schedule, + prompt, + session_target, + model, + delivery, + delete_after_run, + ) + } + }; + + match result { + Ok(job) => Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "id": job.id, + "name": job.name, + "job_type": job.job_type, + "schedule": job.schedule, + "next_run": job.next_run, + "enabled": job.enabled + }))?, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use crate::security::AutonomyLevel; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + fn test_security(cfg: &Config) -> Arc { + Arc::new(SecurityPolicy::from_config( + &cfg.autonomy, + &cfg.workspace_dir, + )) + } + + #[tokio::test] + async fn adds_shell_job() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); + let result = tool + .execute(json!({ + "schedule": { "kind": "cron", "expr": "*/5 * * * *" }, + "job_type": "shell", + "command": "echo ok" + })) + .await + .unwrap(); + + assert!(result.success, "{:?}", result.error); + assert!(result.output.contains("next_run")); + } + + #[tokio::test] + async fn blocks_disallowed_shell_command() { + let tmp = TempDir::new().unwrap(); + let mut config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + config.autonomy.allowed_commands = vec!["echo".into()]; + config.autonomy.level = AutonomyLevel::Supervised; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let cfg = Arc::new(config); + let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); + + let result = tool + .execute(json!({ + "schedule": { "kind": "cron", "expr": "*/5 * * * *" }, + "job_type": "shell", + "command": "curl https://example.com" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("blocked by security policy")); + } + + #[tokio::test] + async fn rejects_invalid_schedule() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); + + let result = tool + .execute(json!({ + "schedule": { "kind": "every", "every_ms": 0 }, + "job_type": "shell", + "command": "echo nope" + })) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("every_ms must be > 0")); + } + + #[tokio::test] + async fn agent_job_requires_prompt() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); + + let result = tool + .execute(json!({ + "schedule": { "kind": "cron", "expr": "*/5 * * * *" }, + "job_type": "agent" + })) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("Missing 'prompt'")); + } +} diff --git a/src/tools/cron_list.rs b/src/tools/cron_list.rs new file mode 100644 index 0000000..0392370 --- /dev/null +++ b/src/tools/cron_list.rs @@ -0,0 +1,101 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +pub struct CronListTool { + config: Arc, +} + +impl CronListTool { + pub fn new(config: Arc) -> Self { + Self { config } + } +} + +#[async_trait] +impl Tool for CronListTool { + fn name(&self) -> &str { + "cron_list" + } + + fn description(&self) -> &str { + "List all scheduled cron jobs" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }) + } + + async fn execute(&self, _args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + match cron::list_jobs(&self.config) { + Ok(jobs) => Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&jobs)?, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + #[tokio::test] + async fn returns_empty_list_when_no_jobs() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronListTool::new(cfg); + + let result = tool.execute(json!({})).await.unwrap(); + assert!(result.success); + assert_eq!(result.output.trim(), "[]"); + } + + #[tokio::test] + async fn errors_when_cron_disabled() { + let tmp = TempDir::new().unwrap(); + let mut cfg = (*test_config(&tmp)).clone(); + cfg.cron.enabled = false; + let tool = CronListTool::new(Arc::new(cfg)); + + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("cron is disabled")); + } +} diff --git a/src/tools/cron_remove.rs b/src/tools/cron_remove.rs new file mode 100644 index 0000000..01a70dc --- /dev/null +++ b/src/tools/cron_remove.rs @@ -0,0 +1,114 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +pub struct CronRemoveTool { + config: Arc, +} + +impl CronRemoveTool { + pub fn new(config: Arc) -> Self { + Self { config } + } +} + +#[async_trait] +impl Tool for CronRemoveTool { + fn name(&self) -> &str { + "cron_remove" + } + + fn description(&self) -> &str { + "Remove a cron job by id" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "job_id": { "type": "string" } + }, + "required": ["job_id"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + let job_id = match args.get("job_id").and_then(serde_json::Value::as_str) { + Some(v) if !v.trim().is_empty() => v, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'job_id' parameter".to_string()), + }); + } + }; + + match cron::remove_job(&self.config, job_id) { + Ok(()) => Ok(ToolResult { + success: true, + output: format!("Removed cron job {job_id}"), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + #[tokio::test] + async fn removes_existing_job() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); + let tool = CronRemoveTool::new(cfg.clone()); + + let result = tool.execute(json!({"job_id": job.id})).await.unwrap(); + assert!(result.success); + assert!(cron::list_jobs(&cfg).unwrap().is_empty()); + } + + #[tokio::test] + async fn errors_when_job_id_missing() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronRemoveTool::new(cfg); + + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("Missing 'job_id'")); + } +} diff --git a/src/tools/cron_run.rs b/src/tools/cron_run.rs new file mode 100644 index 0000000..a4e5f75 --- /dev/null +++ b/src/tools/cron_run.rs @@ -0,0 +1,147 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use async_trait::async_trait; +use chrono::Utc; +use serde_json::json; +use std::sync::Arc; + +pub struct CronRunTool { + config: Arc, +} + +impl CronRunTool { + pub fn new(config: Arc) -> Self { + Self { config } + } +} + +#[async_trait] +impl Tool for CronRunTool { + fn name(&self) -> &str { + "cron_run" + } + + fn description(&self) -> &str { + "Force-run a cron job immediately and record run history" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "job_id": { "type": "string" } + }, + "required": ["job_id"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + let job_id = match args.get("job_id").and_then(serde_json::Value::as_str) { + Some(v) if !v.trim().is_empty() => v, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'job_id' parameter".to_string()), + }); + } + }; + + let job = match cron::get_job(&self.config, job_id) { + Ok(job) => job, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }); + } + }; + + let started_at = Utc::now(); + let (success, output) = cron::scheduler::execute_job_now(&self.config, &job).await; + let finished_at = Utc::now(); + let duration_ms = (finished_at - started_at).num_milliseconds(); + let status = if success { "ok" } else { "error" }; + + let _ = cron::record_run( + &self.config, + &job.id, + started_at, + finished_at, + status, + Some(&output), + duration_ms, + ); + let _ = cron::record_last_run(&self.config, &job.id, finished_at, success, &output); + + Ok(ToolResult { + success, + output: serde_json::to_string_pretty(&json!({ + "job_id": job.id, + "status": status, + "duration_ms": duration_ms, + "output": output + }))?, + error: if success { + None + } else { + Some("cron job execution failed".to_string()) + }, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + #[tokio::test] + async fn force_runs_job_and_records_history() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let job = cron::add_job(&cfg, "*/5 * * * *", "echo run-now").unwrap(); + let tool = CronRunTool::new(cfg.clone()); + + let result = tool.execute(json!({ "job_id": job.id })).await.unwrap(); + assert!(result.success, "{:?}", result.error); + + let runs = cron::list_runs(&cfg, &job.id, 10).unwrap(); + assert_eq!(runs.len(), 1); + } + + #[tokio::test] + async fn errors_for_missing_job() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronRunTool::new(cfg); + + let result = tool + .execute(json!({ "job_id": "missing-job-id" })) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap_or_default().contains("not found")); + } +} diff --git a/src/tools/cron_runs.rs b/src/tools/cron_runs.rs new file mode 100644 index 0000000..280baa1 --- /dev/null +++ b/src/tools/cron_runs.rs @@ -0,0 +1,175 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use async_trait::async_trait; +use serde::Serialize; +use serde_json::json; +use std::sync::Arc; + +const MAX_RUN_OUTPUT_CHARS: usize = 500; + +pub struct CronRunsTool { + config: Arc, +} + +impl CronRunsTool { + pub fn new(config: Arc) -> Self { + Self { config } + } +} + +#[derive(Serialize)] +struct RunView { + id: i64, + job_id: String, + started_at: chrono::DateTime, + finished_at: chrono::DateTime, + status: String, + output: Option, + duration_ms: Option, +} + +#[async_trait] +impl Tool for CronRunsTool { + fn name(&self) -> &str { + "cron_runs" + } + + fn description(&self) -> &str { + "List recent run history for a cron job" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "job_id": { "type": "string" }, + "limit": { "type": "integer" } + }, + "required": ["job_id"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + let job_id = match args.get("job_id").and_then(serde_json::Value::as_str) { + Some(v) if !v.trim().is_empty() => v, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'job_id' parameter".to_string()), + }); + } + }; + + let limit = args + .get("limit") + .and_then(serde_json::Value::as_u64) + .map_or(10, |v| usize::try_from(v).unwrap_or(10)); + + match cron::list_runs(&self.config, job_id, limit) { + Ok(runs) => { + let runs: Vec = runs + .into_iter() + .map(|run| RunView { + id: run.id, + job_id: run.job_id, + started_at: run.started_at, + finished_at: run.finished_at, + status: run.status, + output: run.output.map(|out| truncate(&out, MAX_RUN_OUTPUT_CHARS)), + duration_ms: run.duration_ms, + }) + .collect(); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&runs)?, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +fn truncate(input: &str, max_chars: usize) -> String { + if input.chars().count() <= max_chars { + return input.to_string(); + } + let mut out: String = input.chars().take(max_chars).collect(); + out.push_str("..."); + out +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use chrono::{Duration as ChronoDuration, Utc}; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + #[tokio::test] + async fn lists_runs_with_truncation() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); + + let long_output = "x".repeat(1000); + let now = Utc::now(); + cron::record_run( + &cfg, + &job.id, + now, + now + ChronoDuration::milliseconds(1), + "ok", + Some(&long_output), + 1, + ) + .unwrap(); + + let tool = CronRunsTool::new(cfg.clone()); + let result = tool + .execute(json!({ "job_id": job.id, "limit": 5 })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains("...")); + } + + #[tokio::test] + async fn errors_when_job_id_missing() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let tool = CronRunsTool::new(cfg); + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("Missing 'job_id'")); + } +} diff --git a/src/tools/cron_update.rs b/src/tools/cron_update.rs new file mode 100644 index 0000000..c224b17 --- /dev/null +++ b/src/tools/cron_update.rs @@ -0,0 +1,177 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron::{self, CronJobPatch}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +pub struct CronUpdateTool { + config: Arc, + security: Arc, +} + +impl CronUpdateTool { + pub fn new(config: Arc, security: Arc) -> Self { + Self { config, security } + } +} + +#[async_trait] +impl Tool for CronUpdateTool { + fn name(&self) -> &str { + "cron_update" + } + + fn description(&self) -> &str { + "Patch an existing cron job (schedule, command, prompt, enabled, delivery, model, etc.)" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "job_id": { "type": "string" }, + "patch": { "type": "object" } + }, + "required": ["job_id", "patch"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.config.cron.enabled { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("cron is disabled by config (cron.enabled=false)".to_string()), + }); + } + + let job_id = match args.get("job_id").and_then(serde_json::Value::as_str) { + Some(v) if !v.trim().is_empty() => v, + _ => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'job_id' parameter".to_string()), + }); + } + }; + + let patch_val = match args.get("patch") { + Some(v) => v.clone(), + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'patch' parameter".to_string()), + }); + } + }; + + let patch = match serde_json::from_value::(patch_val) { + Ok(patch) => patch, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Invalid patch payload: {e}")), + }); + } + }; + + if let Some(command) = &patch.command { + if !self.security.is_command_allowed(command) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Command blocked by security policy: {command}")), + }); + } + } + + match cron::update_job(&self.config, job_id, patch) { + Ok(job) => Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&job)?, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Arc { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + Arc::new(config) + } + + fn test_security(cfg: &Config) -> Arc { + Arc::new(SecurityPolicy::from_config( + &cfg.autonomy, + &cfg.workspace_dir, + )) + } + + #[tokio::test] + async fn updates_enabled_flag() { + let tmp = TempDir::new().unwrap(); + let cfg = test_config(&tmp); + let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); + let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg)); + + let result = tool + .execute(json!({ + "job_id": job.id, + "patch": { "enabled": false } + })) + .await + .unwrap(); + + assert!(result.success, "{:?}", result.error); + assert!(result.output.contains("\"enabled\": false")); + } + + #[tokio::test] + async fn blocks_disallowed_command_updates() { + let tmp = TempDir::new().unwrap(); + let mut config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + config.autonomy.allowed_commands = vec!["echo".into()]; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let cfg = Arc::new(config); + let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); + let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg)); + + let result = tool + .execute(json!({ + "job_id": job.id, + "patch": { "command": "curl https://example.com" } + })) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .unwrap_or_default() + .contains("blocked by security policy")); + } +} diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs new file mode 100644 index 0000000..3de7872 --- /dev/null +++ b/src/tools/delegate.rs @@ -0,0 +1,435 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::DelegateAgentConfig; +use crate::providers::{self, Provider}; +use async_trait::async_trait; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +/// Default timeout for sub-agent provider calls. +const DELEGATE_TIMEOUT_SECS: u64 = 120; + +/// Tool that delegates a subtask to a named agent with a different +/// provider/model configuration. Enables multi-agent workflows where +/// a primary agent can hand off specialized work (research, coding, +/// summarization) to purpose-built sub-agents. +pub struct DelegateTool { + agents: Arc>, + /// Global credential fallback (from config.api_key) + fallback_credential: Option, + /// Depth at which this tool instance lives in the delegation chain. + depth: u32, +} + +impl DelegateTool { + pub fn new( + agents: HashMap, + fallback_credential: Option, + ) -> Self { + Self { + agents: Arc::new(agents), + fallback_credential, + depth: 0, + } + } + + /// Create a DelegateTool for a sub-agent (with incremented depth). + /// When sub-agents eventually get their own tool registry, construct + /// their DelegateTool via this method with `depth: parent.depth + 1`. + pub fn with_depth( + agents: HashMap, + fallback_credential: Option, + depth: u32, + ) -> Self { + Self { + agents: Arc::new(agents), + fallback_credential, + depth, + } + } +} + +#[async_trait] +impl Tool for DelegateTool { + fn name(&self) -> &str { + "delegate" + } + + fn description(&self) -> &str { + "Delegate a subtask to a specialized agent. Use when: a task benefits from a different model \ + (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single \ + prompt and returns its response." + } + + fn parameters_schema(&self) -> serde_json::Value { + let agent_names: Vec<&str> = self.agents.keys().map(|s: &String| s.as_str()).collect(); + json!({ + "type": "object", + "additionalProperties": false, + "properties": { + "agent": { + "type": "string", + "minLength": 1, + "description": format!( + "Name of the agent to delegate to. Available: {}", + if agent_names.is_empty() { + "(none configured)".to_string() + } else { + agent_names.join(", ") + } + ) + }, + "prompt": { + "type": "string", + "minLength": 1, + "description": "The task/prompt to send to the sub-agent" + }, + "context": { + "type": "string", + "description": "Optional context to prepend (e.g. relevant code, prior findings)" + } + }, + "required": ["agent", "prompt"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let agent_name = args + .get("agent") + .and_then(|v| v.as_str()) + .map(str::trim) + .ok_or_else(|| anyhow::anyhow!("Missing 'agent' parameter"))?; + + if agent_name.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'agent' parameter must not be empty".into()), + }); + } + + let prompt = args + .get("prompt") + .and_then(|v| v.as_str()) + .map(str::trim) + .ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?; + + if prompt.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'prompt' parameter must not be empty".into()), + }); + } + + let context = args + .get("context") + .and_then(|v| v.as_str()) + .map(str::trim) + .unwrap_or(""); + + // Look up agent config + let agent_config = match self.agents.get(agent_name) { + Some(cfg) => cfg, + None => { + let available: Vec<&str> = + self.agents.keys().map(|s: &String| s.as_str()).collect(); + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown agent '{agent_name}'. Available agents: {}", + if available.is_empty() { + "(none configured)".to_string() + } else { + available.join(", ") + } + )), + }); + } + }; + + // Check recursion depth (immutable — set at construction, incremented for sub-agents) + if self.depth >= agent_config.max_depth { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Delegation depth limit reached ({depth}/{max}). \ + Cannot delegate further to prevent infinite loops.", + depth = self.depth, + max = agent_config.max_depth + )), + }); + } + + // Create provider for this agent + let provider_credential_owned = agent_config + .api_key + .clone() + .or_else(|| self.fallback_credential.clone()); + #[allow(clippy::option_as_ref_deref)] + let provider_credential = provider_credential_owned.as_ref().map(String::as_str); + + let provider: Box = + match providers::create_provider(&agent_config.provider, provider_credential) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Failed to create provider '{}' for agent '{agent_name}': {e}", + agent_config.provider + )), + }); + } + }; + + // Build the message + let full_prompt = if context.is_empty() { + prompt.to_string() + } else { + format!("[Context]\n{context}\n\n[Task]\n{prompt}") + }; + + let temperature = agent_config.temperature.unwrap_or(0.7); + + // Wrap the provider call in a timeout to prevent indefinite blocking + let result = tokio::time::timeout( + Duration::from_secs(DELEGATE_TIMEOUT_SECS), + provider.chat_with_system( + agent_config.system_prompt.as_deref(), + &full_prompt, + &agent_config.model, + temperature, + ), + ) + .await; + + let result = match result { + Ok(inner) => inner, + Err(_elapsed) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Agent '{agent_name}' timed out after {DELEGATE_TIMEOUT_SECS}s" + )), + }); + } + }; + + match result { + Ok(response) => { + let mut rendered = response; + if rendered.trim().is_empty() { + rendered = "[Empty response]".to_string(); + } + + Ok(ToolResult { + success: true, + output: format!( + "[Agent '{agent_name}' ({provider}/{model})]\n{rendered}", + provider = agent_config.provider, + model = agent_config.model + ), + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Agent '{agent_name}' failed: {e}",)), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_agents() -> HashMap { + let mut agents = HashMap::new(); + agents.insert( + "researcher".to_string(), + DelegateAgentConfig { + provider: "ollama".to_string(), + model: "llama3".to_string(), + system_prompt: Some("You are a research assistant.".to_string()), + api_key: None, + temperature: Some(0.3), + max_depth: 3, + }, + ); + agents.insert( + "coder".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "anthropic/claude-sonnet-4-20250514".to_string(), + system_prompt: None, + api_key: Some("delegate-test-credential".to_string()), + temperature: None, + max_depth: 2, + }, + ); + agents + } + + #[test] + fn name_and_schema() { + let tool = DelegateTool::new(sample_agents(), None); + assert_eq!(tool.name(), "delegate"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["agent"].is_object()); + assert!(schema["properties"]["prompt"].is_object()); + assert!(schema["properties"]["context"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("agent"))); + assert!(required.contains(&json!("prompt"))); + assert_eq!(schema["additionalProperties"], json!(false)); + assert_eq!(schema["properties"]["agent"]["minLength"], json!(1)); + assert_eq!(schema["properties"]["prompt"]["minLength"], json!(1)); + } + + #[test] + fn description_not_empty() { + let tool = DelegateTool::new(sample_agents(), None); + assert!(!tool.description().is_empty()); + } + + #[test] + fn schema_lists_agent_names() { + let tool = DelegateTool::new(sample_agents(), None); + let schema = tool.parameters_schema(); + let desc = schema["properties"]["agent"]["description"] + .as_str() + .unwrap(); + assert!(desc.contains("researcher") || desc.contains("coder")); + } + + #[tokio::test] + async fn missing_agent_param() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool.execute(json!({"prompt": "test"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn missing_prompt_param() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool.execute(json!({"agent": "researcher"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn unknown_agent_returns_error() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": "nonexistent", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown agent")); + } + + #[tokio::test] + async fn depth_limit_enforced() { + let tool = DelegateTool::with_depth(sample_agents(), None, 3); + let result = tool + .execute(json!({"agent": "researcher", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("depth limit")); + } + + #[tokio::test] + async fn depth_limit_per_agent() { + // coder has max_depth=2, so depth=2 should be blocked + let tool = DelegateTool::with_depth(sample_agents(), None, 2); + let result = tool + .execute(json!({"agent": "coder", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("depth limit")); + } + + #[test] + fn empty_agents_schema() { + let tool = DelegateTool::new(HashMap::new(), None); + let schema = tool.parameters_schema(); + let desc = schema["properties"]["agent"]["description"] + .as_str() + .unwrap(); + assert!(desc.contains("none configured")); + } + + #[tokio::test] + async fn invalid_provider_returns_error() { + let mut agents = HashMap::new(); + agents.insert( + "broken".to_string(), + DelegateAgentConfig { + provider: "totally-invalid-provider".to_string(), + model: "model".to_string(), + system_prompt: None, + api_key: None, + temperature: None, + max_depth: 3, + }, + ); + let tool = DelegateTool::new(agents, None); + let result = tool + .execute(json!({"agent": "broken", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Failed to create provider")); + } + + #[tokio::test] + async fn blank_agent_rejected() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": " ", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("must not be empty")); + } + + #[tokio::test] + async fn blank_prompt_rejected() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": "researcher", "prompt": " \t "})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("must not be empty")); + } + + #[tokio::test] + async fn whitespace_agent_name_trimmed_and_found() { + let tool = DelegateTool::new(sample_agents(), None); + // " researcher " with surrounding whitespace — after trim becomes "researcher" + let result = tool + .execute(json!({"agent": " researcher ", "prompt": "test"})) + .await + .unwrap(); + // Should find "researcher" after trim — will fail at provider level + // since ollama isn't running, but must NOT get "Unknown agent". + assert!( + result.error.is_none() + || !result + .error + .as_deref() + .unwrap_or("") + .contains("Unknown agent") + ); + } +} diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 97c46e0..c43bd2e 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -4,6 +4,8 @@ use async_trait::async_trait; use serde_json::json; use std::sync::Arc; +const MAX_FILE_SIZE_BYTES: u64 = 10 * 1024 * 1024; + /// Read file contents with path sandboxing pub struct FileReadTool { security: Arc, @@ -44,6 +46,14 @@ impl Tool for FileReadTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: too many actions in the last hour".into()), + }); + } + // Security check: validate path is within workspace if !self.security.is_path_allowed(path) { return Ok(ToolResult { @@ -53,6 +63,17 @@ impl Tool for FileReadTool { }); } + // Record action BEFORE canonicalization so that every non-trivially-rejected + // request consumes rate limit budget. This prevents attackers from probing + // path existence (via canonicalize errors) without rate limit cost. + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + let full_path = self.security.workspace_dir.join(path); // Resolve path before reading to block symlink escapes. @@ -78,6 +99,29 @@ impl Tool for FileReadTool { }); } + // Check file size AFTER canonicalization to prevent TOCTOU symlink bypass + match tokio::fs::metadata(&resolved_path).await { + Ok(meta) => { + if meta.len() > MAX_FILE_SIZE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "File too large: {} bytes (limit: {MAX_FILE_SIZE_BYTES} bytes)", + meta.len() + )), + }); + } + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to read file metadata: {e}")), + }); + } + } + match tokio::fs::read_to_string(&resolved_path).await { Ok(contents) => Ok(ToolResult { success: true, @@ -106,6 +150,19 @@ mod tests { }) } + fn test_security_with( + workspace: std::path::PathBuf, + autonomy: AutonomyLevel, + max_actions_per_hour: u32, + ) -> Arc { + Arc::new(SecurityPolicy { + autonomy, + workspace_dir: workspace, + max_actions_per_hour, + ..SecurityPolicy::default() + }) + } + #[test] fn file_read_name() { let tool = FileReadTool::new(test_security(std::env::temp_dir())); @@ -180,6 +237,50 @@ mod tests { assert!(result.error.as_ref().unwrap().contains("not allowed")); } + #[tokio::test] + async fn file_read_blocks_when_rate_limited() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_read_rate_limited"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + tokio::fs::write(dir.join("test.txt"), "hello world") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security_with( + dir.clone(), + AutonomyLevel::Supervised, + 0, + )); + let result = tool.execute(json!({"path": "test.txt"})).await.unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } + + #[tokio::test] + async fn file_read_allows_readonly_mode() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_read_readonly"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + tokio::fs::write(dir.join("test.txt"), "readonly ok") + .await + .unwrap(); + + let tool = FileReadTool::new(test_security_with(dir.clone(), AutonomyLevel::ReadOnly, 20)); + let result = tool.execute(json!({"path": "test.txt"})).await.unwrap(); + + assert!(result.success); + assert_eq!(result.output, "readonly ok"); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } + #[tokio::test] async fn file_read_missing_path_param() { let tool = FileReadTool::new(test_security(std::env::temp_dir())); @@ -255,4 +356,56 @@ mod tests { let _ = tokio::fs::remove_dir_all(&root).await; } + + #[tokio::test] + async fn file_read_nonexistent_consumes_rate_limit_budget() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_read_probe"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + + // Allow only 2 actions total + let tool = FileReadTool::new(test_security_with( + dir.clone(), + AutonomyLevel::Supervised, + 2, + )); + + // Both reads fail (file doesn't exist) but should consume budget + let r1 = tool.execute(json!({"path": "nope1.txt"})).await.unwrap(); + assert!(!r1.success); + assert!(r1.error.as_ref().unwrap().contains("Failed to resolve")); + + let r2 = tool.execute(json!({"path": "nope2.txt"})).await.unwrap(); + assert!(!r2.success); + assert!(r2.error.as_ref().unwrap().contains("Failed to resolve")); + + // Third attempt should be rate limited even though file doesn't exist + let r3 = tool.execute(json!({"path": "nope3.txt"})).await.unwrap(); + assert!(!r3.success); + assert!( + r3.error.as_ref().unwrap().contains("Rate limit"), + "Expected rate limit error, got: {:?}", + r3.error + ); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } + + #[tokio::test] + async fn file_read_rejects_oversized_file() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_read_large"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + + // Create a file just over 10 MB + let big = vec![b'x'; 10 * 1024 * 1024 + 1]; + tokio::fs::write(dir.join("huge.bin"), &big).await.unwrap(); + + let tool = FileReadTool::new(test_security(dir.clone())); + let result = tool.execute(json!({"path": "huge.bin"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("File too large")); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index 0760a29..620487f 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -53,6 +53,22 @@ impl Tool for FileWriteTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?; + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if self.security.is_rate_limited() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: too many actions in the last hour".into()), + }); + } + // Security check: validate path is within workspace if !self.security.is_path_allowed(path) { return Ok(ToolResult { @@ -64,11 +80,6 @@ impl Tool for FileWriteTool { let full_path = self.security.workspace_dir.join(path); - // Ensure parent directory exists - if let Some(parent) = full_path.parent() { - tokio::fs::create_dir_all(parent).await?; - } - let Some(parent) = full_path.parent() else { return Ok(ToolResult { success: false, @@ -77,7 +88,10 @@ impl Tool for FileWriteTool { }); }; - // Resolve parent before writing to block symlink escapes. + // Ensure parent directory exists + tokio::fs::create_dir_all(parent).await?; + + // Resolve parent AFTER creation to block symlink escapes. let resolved_parent = match tokio::fs::canonicalize(parent).await { Ok(p) => p, Err(e) => { @@ -110,6 +124,28 @@ impl Tool for FileWriteTool { let resolved_target = resolved_parent.join(file_name); + // If the target already exists and is a symlink, refuse to follow it + if let Ok(meta) = tokio::fs::symlink_metadata(&resolved_target).await { + if meta.file_type().is_symlink() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Refusing to write through symlink: {}", + resolved_target.display() + )), + }); + } + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + match tokio::fs::write(&resolved_target, content).await { Ok(()) => Ok(ToolResult { success: true, @@ -138,6 +174,19 @@ mod tests { }) } + fn test_security_with( + workspace: std::path::PathBuf, + autonomy: AutonomyLevel, + max_actions_per_hour: u32, + ) -> Arc { + Arc::new(SecurityPolicy { + autonomy, + workspace_dir: workspace, + max_actions_per_hour, + ..SecurityPolicy::default() + }) + } + #[test] fn file_write_name() { let tool = FileWriteTool::new(test_security(std::env::temp_dir())); @@ -312,4 +361,50 @@ mod tests { let _ = tokio::fs::remove_dir_all(&root).await; } + + #[tokio::test] + async fn file_write_blocks_readonly_mode() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_write_readonly"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + + let tool = FileWriteTool::new(test_security_with(dir.clone(), AutonomyLevel::ReadOnly, 20)); + let result = tool + .execute(json!({"path": "out.txt", "content": "should-block"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.as_deref().unwrap_or("").contains("read-only")); + assert!(!dir.join("out.txt").exists()); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } + + #[tokio::test] + async fn file_write_blocks_when_rate_limited() { + let dir = std::env::temp_dir().join("zeroclaw_test_file_write_rate_limited"); + let _ = tokio::fs::remove_dir_all(&dir).await; + tokio::fs::create_dir_all(&dir).await.unwrap(); + + let tool = FileWriteTool::new(test_security_with( + dir.clone(), + AutonomyLevel::Supervised, + 0, + )); + let result = tool + .execute(json!({"path": "out.txt", "content": "should-block"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + assert!(!dir.join("out.txt").exists()); + + let _ = tokio::fs::remove_dir_all(&dir).await; + } } diff --git a/src/tools/git_operations.rs b/src/tools/git_operations.rs new file mode 100644 index 0000000..21440ba --- /dev/null +++ b/src/tools/git_operations.rs @@ -0,0 +1,769 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::{AutonomyLevel, SecurityPolicy}; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +/// Git operations tool for structured repository management. +/// Provides safe, parsed git operations with JSON output. +pub struct GitOperationsTool { + security: Arc, + workspace_dir: std::path::PathBuf, +} + +impl GitOperationsTool { + pub fn new(security: Arc, workspace_dir: std::path::PathBuf) -> Self { + Self { + security, + workspace_dir, + } + } + + /// Sanitize git arguments to prevent injection attacks + fn sanitize_git_args(&self, args: &str) -> anyhow::Result> { + let mut result = Vec::new(); + for arg in args.split_whitespace() { + // Block dangerous git options that could lead to command injection + let arg_lower = arg.to_lowercase(); + if arg_lower.starts_with("--exec=") + || arg_lower.starts_with("--upload-pack=") + || arg_lower.starts_with("--receive-pack=") + || arg_lower.starts_with("--pager=") + || arg_lower.starts_with("--editor=") + || arg_lower == "--no-verify" + || arg_lower.contains("$(") + || arg_lower.contains('`') + || arg.contains('|') + || arg.contains(';') + || arg.contains('>') + { + anyhow::bail!("Blocked potentially dangerous git argument: {arg}"); + } + // Block `-c` config injection (exact match or `-c=...` prefix). + // This must not false-positive on `--cached` or `-cached`. + if arg_lower == "-c" || arg_lower.starts_with("-c=") { + anyhow::bail!("Blocked potentially dangerous git argument: {arg}"); + } + result.push(arg.to_string()); + } + Ok(result) + } + + /// Check if an operation requires write access + fn requires_write_access(&self, operation: &str) -> bool { + matches!( + operation, + "commit" | "add" | "checkout" | "branch" | "stash" | "reset" | "revert" + ) + } + + /// Check if an operation is read-only + fn is_read_only(&self, operation: &str) -> bool { + matches!( + operation, + "status" | "diff" | "log" | "show" | "branch" | "rev-parse" + ) + } + + async fn run_git_command(&self, args: &[&str]) -> anyhow::Result { + let output = tokio::process::Command::new("git") + .args(args) + .current_dir(&self.workspace_dir) + .output() + .await?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + anyhow::bail!("Git command failed: {stderr}"); + } + + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } + + async fn git_status(&self, _args: serde_json::Value) -> anyhow::Result { + let output = self + .run_git_command(&["status", "--porcelain=2", "--branch"]) + .await?; + + // Parse git status output into structured format + let mut result = serde_json::Map::new(); + let mut branch = String::new(); + let mut staged = Vec::new(); + let mut unstaged = Vec::new(); + let mut untracked = Vec::new(); + + for line in output.lines() { + if line.starts_with("# branch.head ") { + branch = line.trim_start_matches("# branch.head ").to_string(); + } else if let Some(rest) = line.strip_prefix("1 ") { + // Ordinary changed entry + let mut parts = rest.splitn(3, ' '); + if let (Some(staging), Some(path)) = (parts.next(), parts.next()) { + if !staging.is_empty() { + let status_char = staging.chars().next().unwrap_or(' '); + if status_char != '.' && status_char != ' ' { + staged.push(json!({"path": path, "status": status_char})); + } + let status_char = staging.chars().nth(1).unwrap_or(' '); + if status_char != '.' && status_char != ' ' { + unstaged.push(json!({"path": path, "status": status_char})); + } + } + } + } else if let Some(rest) = line.strip_prefix("? ") { + untracked.push(rest.to_string()); + } + } + + result.insert("branch".to_string(), json!(branch)); + result.insert("staged".to_string(), json!(staged)); + result.insert("unstaged".to_string(), json!(unstaged)); + result.insert("untracked".to_string(), json!(untracked)); + result.insert( + "clean".to_string(), + json!(staged.is_empty() && unstaged.is_empty() && untracked.is_empty()), + ); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result).unwrap_or_default(), + error: None, + }) + } + + async fn git_diff(&self, args: serde_json::Value) -> anyhow::Result { + let files = args.get("files").and_then(|v| v.as_str()).unwrap_or("."); + let cached = args + .get("cached") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + // Validate files argument against injection patterns + self.sanitize_git_args(files)?; + + let mut git_args = vec!["diff", "--unified=3"]; + if cached { + git_args.push("--cached"); + } + git_args.push("--"); + git_args.push(files); + + let output = self.run_git_command(&git_args).await?; + + // Parse diff into structured hunks + let mut result = serde_json::Map::new(); + let mut hunks = Vec::new(); + let mut current_file = String::new(); + let mut current_hunk = serde_json::Map::new(); + let mut lines = Vec::new(); + + for line in output.lines() { + if line.starts_with("diff --git ") { + if !lines.is_empty() { + current_hunk.insert("lines".to_string(), json!(lines)); + if !current_hunk.is_empty() { + hunks.push(serde_json::Value::Object(current_hunk.clone())); + } + lines = Vec::new(); + current_hunk = serde_json::Map::new(); + } + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 4 { + current_file = parts[3].trim_start_matches("b/").to_string(); + current_hunk.insert("file".to_string(), json!(current_file)); + } + } else if line.starts_with("@@ ") { + if !lines.is_empty() { + current_hunk.insert("lines".to_string(), json!(lines)); + if !current_hunk.is_empty() { + hunks.push(serde_json::Value::Object(current_hunk.clone())); + } + lines = Vec::new(); + current_hunk = serde_json::Map::new(); + current_hunk.insert("file".to_string(), json!(current_file)); + } + current_hunk.insert("header".to_string(), json!(line)); + } else if !line.is_empty() { + lines.push(json!({ + "text": line, + "type": if line.starts_with('+') { "add" } + else if line.starts_with('-') { "delete" } + else { "context" } + })); + } + } + + if !lines.is_empty() { + current_hunk.insert("lines".to_string(), json!(lines)); + if !current_hunk.is_empty() { + hunks.push(serde_json::Value::Object(current_hunk)); + } + } + + result.insert("hunks".to_string(), json!(hunks)); + result.insert("file_count".to_string(), json!(hunks.len())); + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&result).unwrap_or_default(), + error: None, + }) + } + + async fn git_log(&self, args: serde_json::Value) -> anyhow::Result { + let limit_raw = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10); + let limit = usize::try_from(limit_raw).unwrap_or(usize::MAX).min(1000); + let limit_str = limit.to_string(); + + let output = self + .run_git_command(&[ + "log", + &format!("-{limit_str}"), + "--pretty=format:%H|%an|%ae|%ad|%s", + "--date=iso", + ]) + .await?; + + let mut commits = Vec::new(); + + for line in output.lines() { + let parts: Vec<&str> = line.split('|').collect(); + if parts.len() >= 5 { + commits.push(json!({ + "hash": parts[0], + "author": parts[1], + "email": parts[2], + "date": parts[3], + "message": parts[4] + })); + } + } + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ "commits": commits })) + .unwrap_or_default(), + error: None, + }) + } + + async fn git_branch(&self, _args: serde_json::Value) -> anyhow::Result { + let output = self + .run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"]) + .await?; + + let mut branches = Vec::new(); + let mut current = String::new(); + + for line in output.lines() { + if let Some((name, head)) = line.split_once('|') { + let is_current = head == "*"; + if is_current { + current = name.to_string(); + } + branches.push(json!({ + "name": name, + "current": is_current + })); + } + } + + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&json!({ + "current": current, + "branches": branches + })) + .unwrap_or_default(), + error: None, + }) + } + + fn truncate_commit_message(message: &str) -> String { + if message.chars().count() > 2000 { + format!("{}...", message.chars().take(1997).collect::()) + } else { + message.to_string() + } + } + + async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result { + let message = args + .get("message") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?; + + // Sanitize commit message + let sanitized = message + .lines() + .map(|l| l.trim()) + .filter(|l| !l.is_empty()) + .collect::>() + .join("\n"); + + if sanitized.is_empty() { + anyhow::bail!("Commit message cannot be empty"); + } + + // Limit message length + let message = Self::truncate_commit_message(&sanitized); + + let output = self.run_git_command(&["commit", "-m", &message]).await; + + match output { + Ok(_) => Ok(ToolResult { + success: true, + output: format!("Committed: {message}"), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Commit failed: {e}")), + }), + } + } + + async fn git_add(&self, args: serde_json::Value) -> anyhow::Result { + let paths = args + .get("paths") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?; + + // Validate paths against injection patterns + self.sanitize_git_args(paths)?; + + let output = self.run_git_command(&["add", "--", paths]).await; + + match output { + Ok(_) => Ok(ToolResult { + success: true, + output: format!("Staged: {paths}"), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Add failed: {e}")), + }), + } + } + + async fn git_checkout(&self, args: serde_json::Value) -> anyhow::Result { + let branch = args + .get("branch") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'branch' parameter"))?; + + // Sanitize branch name + let sanitized = self.sanitize_git_args(branch)?; + + if sanitized.is_empty() || sanitized.len() > 1 { + anyhow::bail!("Invalid branch specification"); + } + + let branch_name = &sanitized[0]; + + // Block dangerous branch names + if branch_name.contains('@') || branch_name.contains('^') || branch_name.contains('~') { + anyhow::bail!("Branch name contains invalid characters"); + } + + let output = self.run_git_command(&["checkout", branch_name]).await; + + match output { + Ok(_) => Ok(ToolResult { + success: true, + output: format!("Switched to branch: {branch_name}"), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Checkout failed: {e}")), + }), + } + } + + async fn git_stash(&self, args: serde_json::Value) -> anyhow::Result { + let action = args + .get("action") + .and_then(|v| v.as_str()) + .unwrap_or("push"); + + let output = match action { + "push" | "save" => { + self.run_git_command(&["stash", "push", "-m", "auto-stash"]) + .await + } + "pop" => self.run_git_command(&["stash", "pop"]).await, + "list" => self.run_git_command(&["stash", "list"]).await, + "drop" => { + let index_raw = args.get("index").and_then(|v| v.as_u64()).unwrap_or(0); + let index = i32::try_from(index_raw) + .map_err(|_| anyhow::anyhow!("stash index too large: {index_raw}"))?; + self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")]) + .await + } + _ => anyhow::bail!("Unknown stash action: {action}. Use: push, pop, list, drop"), + }; + + match output { + Ok(out) => Ok(ToolResult { + success: true, + output: out, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Stash {action} failed: {e}")), + }), + } + } +} + +#[async_trait] +impl Tool for GitOperationsTool { + fn name(&self) -> &str { + "git_operations" + } + + fn description(&self) -> &str { + "Perform structured Git operations (status, diff, log, branch, commit, add, checkout, stash). Provides parsed JSON output and integrates with security policy for autonomy controls." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["status", "diff", "log", "branch", "commit", "add", "checkout", "stash"], + "description": "Git operation to perform" + }, + "message": { + "type": "string", + "description": "Commit message (for 'commit' operation)" + }, + "paths": { + "type": "string", + "description": "File paths to stage (for 'add' operation)" + }, + "branch": { + "type": "string", + "description": "Branch name (for 'checkout' operation)" + }, + "files": { + "type": "string", + "description": "File or path to diff (for 'diff' operation, default: '.')" + }, + "cached": { + "type": "boolean", + "description": "Show staged changes (for 'diff' operation)" + }, + "limit": { + "type": "integer", + "description": "Number of log entries (for 'log' operation, default: 10)" + }, + "action": { + "type": "string", + "enum": ["push", "pop", "list", "drop"], + "description": "Stash action (for 'stash' operation)" + }, + "index": { + "type": "integer", + "description": "Stash index (for 'stash' with 'drop' action)" + } + }, + "required": ["operation"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let operation = match args.get("operation").and_then(|v| v.as_str()) { + Some(op) => op, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Missing 'operation' parameter".into()), + }); + } + }; + + // Check if we're in a git repository + if !self.workspace_dir.join(".git").exists() { + // Try to find .git in parent directories + let mut current_dir = self.workspace_dir.as_path(); + let mut found_git = false; + while current_dir.parent().is_some() { + if current_dir.join(".git").exists() { + found_git = true; + break; + } + current_dir = current_dir.parent().unwrap(); + } + + if !found_git { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Not in a git repository".into()), + }); + } + } + + // Check autonomy level for write operations + if self.requires_write_access(operation) { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "Action blocked: git write operations require higher autonomy level".into(), + ), + }); + } + + match self.security.autonomy { + AutonomyLevel::ReadOnly => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: read-only mode".into()), + }); + } + AutonomyLevel::Supervised | AutonomyLevel::Full => {} + } + } + + // Record action for rate limiting + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + // Execute the requested operation + match operation { + "status" => self.git_status(args).await, + "diff" => self.git_diff(args).await, + "log" => self.git_log(args).await, + "branch" => self.git_branch(args).await, + "commit" => self.git_commit(args).await, + "add" => self.git_add(args).await, + "checkout" => self.git_checkout(args).await, + "stash" => self.git_stash(args).await, + _ => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Unknown operation: {operation}")), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::SecurityPolicy; + use tempfile::TempDir; + + fn test_tool(dir: &std::path::Path) -> GitOperationsTool { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + ..SecurityPolicy::default() + }); + GitOperationsTool::new(security, dir.to_path_buf()) + } + + #[test] + fn sanitize_git_blocks_injection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // Should block dangerous arguments + assert!(tool.sanitize_git_args("--exec=rm -rf /").is_err()); + assert!(tool.sanitize_git_args("$(echo pwned)").is_err()); + assert!(tool.sanitize_git_args("`malicious`").is_err()); + assert!(tool.sanitize_git_args("arg | cat").is_err()); + assert!(tool.sanitize_git_args("arg; rm file").is_err()); + } + + #[test] + fn sanitize_git_blocks_pager_editor_injection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("--pager=less").is_err()); + assert!(tool.sanitize_git_args("--editor=vim").is_err()); + } + + #[test] + fn sanitize_git_blocks_config_injection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // Exact `-c` flag (config injection) + assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err()); + assert!(tool.sanitize_git_args("-c=core.pager=less").is_err()); + } + + #[test] + fn sanitize_git_blocks_no_verify() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("--no-verify").is_err()); + } + + #[test] + fn sanitize_git_blocks_redirect_in_args() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err()); + } + + #[test] + fn sanitize_git_cached_not_blocked() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // --cached must NOT be blocked by the `-c` check + assert!(tool.sanitize_git_args("--cached").is_ok()); + // Other safe flags starting with -c prefix + assert!(tool.sanitize_git_args("-cached").is_ok()); + } + + #[test] + fn sanitize_git_allows_safe() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // Should allow safe arguments + assert!(tool.sanitize_git_args("main").is_ok()); + assert!(tool.sanitize_git_args("feature/test-branch").is_ok()); + assert!(tool.sanitize_git_args("--cached").is_ok()); + assert!(tool.sanitize_git_args("src/main.rs").is_ok()); + assert!(tool.sanitize_git_args(".").is_ok()); + } + + #[test] + fn requires_write_detection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.requires_write_access("commit")); + assert!(tool.requires_write_access("add")); + assert!(tool.requires_write_access("checkout")); + + assert!(!tool.requires_write_access("status")); + assert!(!tool.requires_write_access("diff")); + assert!(!tool.requires_write_access("log")); + } + + #[test] + fn is_read_only_detection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.is_read_only("status")); + assert!(tool.is_read_only("diff")); + assert!(tool.is_read_only("log")); + + assert!(!tool.is_read_only("commit")); + assert!(!tool.is_read_only("add")); + } + + #[tokio::test] + async fn blocks_readonly_mode_for_write_ops() { + let tmp = TempDir::new().unwrap(); + // Initialize a git repository + std::process::Command::new("git") + .args(["init"]) + .current_dir(tmp.path()) + .output() + .unwrap(); + + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = GitOperationsTool::new(security, tmp.path().to_path_buf()); + + let result = tool + .execute(json!({"operation": "commit", "message": "test"})) + .await + .unwrap(); + assert!(!result.success); + // can_act() returns false for ReadOnly, so we get the "higher autonomy level" message + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("higher autonomy")); + } + + #[tokio::test] + async fn allows_readonly_ops_in_readonly_mode() { + let tmp = TempDir::new().unwrap(); + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = GitOperationsTool::new(security, tmp.path().to_path_buf()); + + // This will fail because there's no git repo, but it shouldn't be blocked by autonomy + let result = tool.execute(json!({"operation": "status"})).await.unwrap(); + // The error should be about not being in a git repo, not about read-only mode + let error_msg = result.error.as_deref().unwrap_or(""); + assert!(error_msg.contains("git repository") || error_msg.contains("Git command failed")); + } + + #[tokio::test] + async fn rejects_missing_operation() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + let result = tool.execute(json!({})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Missing 'operation'")); + } + + #[tokio::test] + async fn rejects_unknown_operation() { + let tmp = TempDir::new().unwrap(); + // Initialize a git repository + std::process::Command::new("git") + .args(["init"]) + .current_dir(tmp.path()) + .output() + .unwrap(); + + let tool = test_tool(tmp.path()); + + let result = tool.execute(json!({"operation": "push"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Unknown operation")); + } + + #[test] + fn truncates_multibyte_commit_message_without_panicking() { + let long = "🦀".repeat(2500); + let truncated = GitOperationsTool::truncate_commit_message(&long); + + assert_eq!(truncated.chars().count(), 2000); + } +} diff --git a/src/tools/hardware_board_info.rs b/src/tools/hardware_board_info.rs new file mode 100644 index 0000000..73b30fc --- /dev/null +++ b/src/tools/hardware_board_info.rs @@ -0,0 +1,208 @@ +//! Hardware board info tool — returns chip name, architecture, memory map for Telegram/agent. +//! +//! Use when user asks "what board do I have?", "board info", "connected hardware", etc. +//! Uses probe-rs for Nucleo when available; otherwise static datasheet info. + +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; + +/// Static board info (datasheets). Used when probe-rs is unavailable. +const BOARD_INFO: &[(&str, &str, &str)] = &[ + ( + "nucleo-f401re", + "STM32F401RET6", + "ARM Cortex-M4, 84 MHz. Flash: 512 KB, RAM: 128 KB. User LED on PA5 (pin 13).", + ), + ( + "nucleo-f411re", + "STM32F411RET6", + "ARM Cortex-M4, 100 MHz. Flash: 512 KB, RAM: 128 KB. User LED on PA5 (pin 13).", + ), + ( + "arduino-uno", + "ATmega328P", + "8-bit AVR, 16 MHz. Flash: 16 KB, SRAM: 2 KB. Built-in LED on pin 13.", + ), + ( + "arduino-uno-q", + "STM32U585 + Qualcomm", + "Dual-core: STM32 (MCU) + Linux (aarch64). GPIO via Bridge app on port 9999.", + ), + ( + "esp32", + "ESP32", + "Dual-core Xtensa LX6, 240 MHz. Flash: 4 MB typical. Built-in LED on GPIO 2.", + ), + ( + "rpi-gpio", + "Raspberry Pi", + "ARM Linux. Native GPIO via sysfs/rppal. No fixed LED pin.", + ), +]; + +/// Tool: return full board info (chip, architecture, memory map) for agent/Telegram. +pub struct HardwareBoardInfoTool { + boards: Vec, +} + +impl HardwareBoardInfoTool { + pub fn new(boards: Vec) -> Self { + Self { boards } + } + + fn static_info_for_board(&self, board: &str) -> Option { + BOARD_INFO + .iter() + .find(|(b, _, _)| *b == board) + .map(|(_, chip, desc)| { + format!( + "**Board:** {}\n**Chip:** {}\n**Description:** {}", + board, chip, desc + ) + }) + } +} + +#[async_trait] +impl Tool for HardwareBoardInfoTool { + fn name(&self) -> &str { + "hardware_board_info" + } + + fn description(&self) -> &str { + "Return full board info (chip, architecture, memory map) for connected hardware. Use when: user asks for 'board info', 'what board do I have', 'connected hardware', 'chip info', 'what hardware', or 'memory map'." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "board": { + "type": "string", + "description": "Optional board name (e.g. nucleo-f401re). If omitted, returns info for first configured board." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let board = args + .get("board") + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| self.boards.first().cloned()); + + let board = board.as_deref().unwrap_or("unknown"); + + if self.boards.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No peripherals configured. Add boards to config.toml [peripherals.boards]." + .into(), + ), + }); + } + + let mut output = String::new(); + + #[cfg(feature = "probe")] + if board == "nucleo-f401re" || board == "nucleo-f411re" { + let chip = if board == "nucleo-f411re" { + "STM32F411RETx" + } else { + "STM32F401RETx" + }; + match probe_board_info(chip) { + Ok(info) => { + return Ok(ToolResult { + success: true, + output: info, + error: None, + }); + } + Err(e) => { + use std::fmt::Write; + let _ = write!( + output, + "probe-rs attach failed: {e}. Using static info.\n\n" + ); + } + } + } + + if let Some(info) = self.static_info_for_board(board) { + output.push_str(&info); + if let Some(mem) = memory_map_static(board) { + use std::fmt::Write; + let _ = write!(output, "\n\n**Memory map:**\n{mem}"); + } + } else { + use std::fmt::Write; + let _ = write!( + output, + "Board '{board}' configured. No static info available." + ); + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(feature = "probe")] +fn probe_board_info(chip: &str) -> anyhow::Result { + use probe_rs::config::MemoryRegion; + use probe_rs::{Session, SessionConfig}; + + let session = Session::auto_attach(chip, SessionConfig::default()) + .map_err(|e| anyhow::anyhow!("{}", e))?; + let target = session.target(); + let arch = session.architecture(); + + let mut out = format!( + "**Board:** {}\n**Chip:** {}\n**Architecture:** {:?}\n\n**Memory map:**\n", + chip, target.name, arch + ); + for region in target.memory_map.iter() { + match region { + MemoryRegion::Ram(ram) => { + let (start, end) = (ram.range.start, ram.range.end); + out.push_str(&format!( + "RAM: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, + end, + (end - start) / 1024 + )); + } + MemoryRegion::Nvm(flash) => { + let (start, end) = (flash.range.start, flash.range.end); + out.push_str(&format!( + "Flash: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, + end, + (end - start) / 1024 + )); + } + _ => {} + } + } + out.push_str("\n(Info read via USB/SWD — no firmware on target needed.)"); + Ok(out) +} + +fn memory_map_static(board: &str) -> Option<&'static str> { + match board { + "nucleo-f401re" | "nucleo-f411re" => Some( + "Flash: 0x0800_0000 - 0x0807_FFFF (512 KB)\nRAM: 0x2000_0000 - 0x2001_FFFF (128 KB)", + ), + "arduino-uno" => Some("Flash: 16 KB, SRAM: 2 KB, EEPROM: 1 KB"), + "esp32" => Some("Flash: 4 MB, IRAM/DRAM per ESP-IDF layout"), + _ => None, + } +} diff --git a/src/tools/hardware_memory_map.rs b/src/tools/hardware_memory_map.rs new file mode 100644 index 0000000..41fd07b --- /dev/null +++ b/src/tools/hardware_memory_map.rs @@ -0,0 +1,207 @@ +//! Hardware memory map tool — returns flash/RAM address ranges for connected boards. +//! +//! Phase B: When user asks "what are the upper and lower memory addresses?", this tool +//! returns the memory map. Uses probe-rs for Nucleo/STM32 when available; otherwise +//! returns static maps from datasheets. + +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; + +/// Known memory maps (from datasheets). Used when probe-rs is unavailable. +const MEMORY_MAPS: &[(&str, &str)] = &[ + ( + "nucleo-f401re", + "Flash: 0x0800_0000 - 0x0807_FFFF (512 KB)\nRAM: 0x2000_0000 - 0x2001_FFFF (128 KB)\nSTM32F401RET6, ARM Cortex-M4", + ), + ( + "nucleo-f411re", + "Flash: 0x0800_0000 - 0x0807_FFFF (512 KB)\nRAM: 0x2000_0000 - 0x2001_FFFF (128 KB)\nSTM32F411RET6, ARM Cortex-M4", + ), + ( + "arduino-uno", + "Flash: 0x0000 - 0x3FFF (16 KB, ATmega328P)\nSRAM: 0x0100 - 0x08FF (2 KB)\nEEPROM: 0x0000 - 0x03FF (1 KB)", + ), + ( + "arduino-mega", + "Flash: 0x0000 - 0x3FFFF (256 KB, ATmega2560)\nSRAM: 0x0200 - 0x21FF (8 KB)\nEEPROM: 0x0000 - 0x0FFF (4 KB)", + ), + ( + "esp32", + "Flash: 0x3F40_0000 - 0x3F7F_FFFF (4 MB typical)\nIRAM: 0x4000_0000 - 0x4005_FFFF\nDRAM: 0x3FFB_0000 - 0x3FFF_FFFF", + ), +]; + +/// Tool: report hardware memory map for connected boards. +pub struct HardwareMemoryMapTool { + boards: Vec, +} + +impl HardwareMemoryMapTool { + pub fn new(boards: Vec) -> Self { + Self { boards } + } + + fn static_map_for_board(&self, board: &str) -> Option<&'static str> { + MEMORY_MAPS + .iter() + .find(|(b, _)| *b == board) + .map(|(_, m)| *m) + } +} + +#[async_trait] +impl Tool for HardwareMemoryMapTool { + fn name(&self) -> &str { + "hardware_memory_map" + } + + fn description(&self) -> &str { + "Return the memory map (flash and RAM address ranges) for connected hardware. Use when: user asks for 'upper and lower memory addresses', 'memory map', 'address space', or 'readable addresses'. Returns flash/RAM ranges from datasheets." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "board": { + "type": "string", + "description": "Optional board name (e.g. nucleo-f401re, arduino-uno). If omitted, returns map for first configured board." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let board = args + .get("board") + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| self.boards.first().cloned()); + + let board = board.as_deref().unwrap_or("unknown"); + + if self.boards.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No peripherals configured. Add boards to config.toml [peripherals.boards]." + .into(), + ), + }); + } + + let mut output = String::new(); + + #[cfg(feature = "probe")] + let probe_ok = { + if board == "nucleo-f401re" || board == "nucleo-f411re" { + let chip = if board == "nucleo-f411re" { + "STM32F411RETx" + } else { + "STM32F401RETx" + }; + match probe_rs_memory_map(chip) { + Ok(probe_msg) => { + output.push_str(&format!("**{}** (via probe-rs):\n{}\n", board, probe_msg)); + true + } + Err(e) => { + output.push_str(&format!("Probe-rs failed: {}. ", e)); + false + } + } + } else { + false + } + }; + + #[cfg(not(feature = "probe"))] + let probe_ok = false; + + if !probe_ok { + if let Some(map) = self.static_map_for_board(board) { + use std::fmt::Write; + let _ = write!(output, "**{board}** (from datasheet):\n{map}"); + } else { + use std::fmt::Write; + let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect(); + let _ = write!( + output, + "No memory map for board '{board}'. Known boards: {}", + known.join(", ") + ); + } + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(feature = "probe")] +fn probe_rs_memory_map(chip: &str) -> anyhow::Result { + use probe_rs::config::MemoryRegion; + use probe_rs::{Session, SessionConfig}; + + let session = Session::auto_attach(chip, SessionConfig::default()) + .map_err(|e| anyhow::anyhow!("probe-rs attach failed: {}", e))?; + + let target = session.target(); + let mut out = String::new(); + + for region in target.memory_map.iter() { + match region { + MemoryRegion::Ram(ram) => { + let start = ram.range.start; + let end = ram.range.end; + let size_kb = (end - start) / 1024; + out.push_str(&format!( + "RAM: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, end, size_kb + )); + } + MemoryRegion::Nvm(flash) => { + let start = flash.range.start; + let end = flash.range.end; + let size_kb = (end - start) / 1024; + out.push_str(&format!( + "Flash: 0x{:08X} - 0x{:08X} ({} KB)\n", + start, end, size_kb + )); + } + _ => {} + } + } + + if out.is_empty() { + out = "Could not read memory regions from probe.".to_string(); + } + + Ok(out) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn static_map_nucleo() { + let tool = HardwareMemoryMapTool::new(vec!["nucleo-f401re".into()]); + assert!(tool.static_map_for_board("nucleo-f401re").is_some()); + assert!(tool + .static_map_for_board("nucleo-f401re") + .unwrap() + .contains("Flash")); + } + + #[test] + fn static_map_arduino() { + let tool = HardwareMemoryMapTool::new(vec!["arduino-uno".into()]); + assert!(tool.static_map_for_board("arduino-uno").is_some()); + } +} diff --git a/src/tools/hardware_memory_read.rs b/src/tools/hardware_memory_read.rs new file mode 100644 index 0000000..3232c78 --- /dev/null +++ b/src/tools/hardware_memory_read.rs @@ -0,0 +1,183 @@ +//! Hardware memory read tool — read actual memory/register values from Nucleo via probe-rs. +//! +//! Use when user asks to "read register values", "read memory at address", "dump lower memory", etc. +//! Requires probe feature and Nucleo connected via USB. + +use super::traits::{Tool, ToolResult}; +use async_trait::async_trait; +use serde_json::json; + +/// RAM base for Nucleo-F401RE (STM32F401) +const NUCLEO_RAM_BASE: u64 = 0x2000_0000; + +/// Tool: read memory at address from connected Nucleo via probe-rs. +pub struct HardwareMemoryReadTool { + boards: Vec, +} + +impl HardwareMemoryReadTool { + pub fn new(boards: Vec) -> Self { + Self { boards } + } + + fn chip_for_board(board: &str) -> Option<&'static str> { + match board { + "nucleo-f401re" => Some("STM32F401RETx"), + "nucleo-f411re" => Some("STM32F411RETx"), + _ => None, + } + } +} + +#[async_trait] +impl Tool for HardwareMemoryReadTool { + fn name(&self) -> &str { + "hardware_memory_read" + } + + fn description(&self) -> &str { + "Read actual memory/register values from Nucleo via USB. Use when: user asks to 'read register values', 'read memory at address', 'dump memory', 'lower memory 0-126', or 'give address and value'. Returns hex dump. Requires Nucleo connected via USB and probe feature. Params: address (hex, e.g. 0x20000000 for RAM start), length (bytes, default 128)." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "address": { + "type": "string", + "description": "Memory address in hex (e.g. 0x20000000 for RAM start). Default: 0x20000000 (RAM base)." + }, + "length": { + "type": "integer", + "description": "Number of bytes to read (default 128, max 256)." + }, + "board": { + "type": "string", + "description": "Board name (nucleo-f401re). Optional if only one configured." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if self.boards.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No peripherals configured. Add nucleo-f401re to config.toml [peripherals.boards]." + .into(), + ), + }); + } + + let board = args + .get("board") + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| self.boards.first().cloned()) + .unwrap_or_else(|| "nucleo-f401re".into()); + + let chip = Self::chip_for_board(&board); + if chip.is_none() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Memory read only supports nucleo-f401re, nucleo-f411re. Got: {}", + board + )), + }); + } + + let address_str = args + .get("address") + .and_then(|v| v.as_str()) + .unwrap_or("0x20000000"); + let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE); + + let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128); + let _length = usize::try_from(requested_length) + .unwrap_or(256) + .clamp(1, 256); + + #[cfg(feature = "probe")] + { + match probe_read_memory(chip.unwrap(), _address, _length) { + Ok(output) => { + return Ok(ToolResult { + success: true, + output, + error: None, + }); + } + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "probe-rs read failed: {}. Ensure Nucleo is connected via USB and built with --features probe.", + e + )), + }); + } + } + } + + #[cfg(not(feature = "probe"))] + { + Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "Memory read requires probe feature. Build with: cargo build --features hardware,probe" + .into(), + ), + }) + } + } +} + +fn parse_hex_address(s: &str) -> Option { + let s = s.trim().trim_start_matches("0x").trim_start_matches("0X"); + u64::from_str_radix(s, 16).ok() +} + +#[cfg(feature = "probe")] +fn probe_read_memory(chip: &str, address: u64, length: usize) -> anyhow::Result { + use probe_rs::MemoryInterface; + use probe_rs::Session; + use probe_rs::SessionConfig; + + let mut session = Session::auto_attach(chip, SessionConfig::default()) + .map_err(|e| anyhow::anyhow!("{}", e))?; + + let mut core = session.core(0)?; + let mut buf = vec![0u8; length]; + core.read_8(address, &mut buf) + .map_err(|e| anyhow::anyhow!("{}", e))?; + + // Format as hex dump: address | bytes (16 per line) + let mut out = format!("Memory read from 0x{:08X} ({} bytes):\n\n", address, length); + const COLS: usize = 16; + for (i, chunk) in buf.chunks(COLS).enumerate() { + let addr = address + (i * COLS) as u64; + let hex: String = chunk + .iter() + .map(|b| format!("{:02X}", b)) + .collect::>() + .join(" "); + let ascii: String = chunk + .iter() + .map(|&b| { + if b.is_ascii_graphic() || b == b' ' { + b as char + } else { + '.' + } + }) + .collect(); + out.push_str(&format!("0x{:08X} {:48} {}\n", addr, hex, ascii)); + } + Ok(out) +} diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs new file mode 100644 index 0000000..1d00253 --- /dev/null +++ b/src/tools/http_request.rs @@ -0,0 +1,802 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; +use std::time::Duration; + +/// HTTP request tool for API interactions. +/// Supports GET, POST, PUT, DELETE methods with configurable security. +pub struct HttpRequestTool { + security: Arc, + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, +} + +impl HttpRequestTool { + pub fn new( + security: Arc, + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, + ) -> Self { + Self { + security, + allowed_domains: normalize_allowed_domains(allowed_domains), + max_response_size, + timeout_secs, + } + } + + fn validate_url(&self, raw_url: &str) -> anyhow::Result { + let url = raw_url.trim(); + + if url.is_empty() { + anyhow::bail!("URL cannot be empty"); + } + + if url.chars().any(char::is_whitespace) { + anyhow::bail!("URL cannot contain whitespace"); + } + + if !url.starts_with("http://") && !url.starts_with("https://") { + anyhow::bail!("Only http:// and https:// URLs are allowed"); + } + + if self.allowed_domains.is_empty() { + anyhow::bail!( + "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml" + ); + } + + let host = extract_host(url)?; + + if is_private_or_local_host(&host) { + anyhow::bail!("Blocked local/private host: {host}"); + } + + if !host_matches_allowlist(&host, &self.allowed_domains) { + anyhow::bail!("Host '{host}' is not in http_request.allowed_domains"); + } + + Ok(url.to_string()) + } + + fn validate_method(&self, method: &str) -> anyhow::Result { + match method.to_uppercase().as_str() { + "GET" => Ok(reqwest::Method::GET), + "POST" => Ok(reqwest::Method::POST), + "PUT" => Ok(reqwest::Method::PUT), + "DELETE" => Ok(reqwest::Method::DELETE), + "PATCH" => Ok(reqwest::Method::PATCH), + "HEAD" => Ok(reqwest::Method::HEAD), + "OPTIONS" => Ok(reqwest::Method::OPTIONS), + _ => anyhow::bail!("Unsupported HTTP method: {method}. Supported: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS"), + } + } + + fn parse_headers(&self, headers: &serde_json::Value) -> Vec<(String, String)> { + let mut result = Vec::new(); + if let Some(obj) = headers.as_object() { + for (key, value) in obj { + if let Some(str_val) = value.as_str() { + result.push((key.clone(), str_val.to_string())); + } + } + } + result + } + + fn redact_headers_for_display(headers: &[(String, String)]) -> Vec<(String, String)> { + headers + .iter() + .map(|(key, value)| { + let lower = key.to_lowercase(); + let is_sensitive = lower.contains("authorization") + || lower.contains("api-key") + || lower.contains("apikey") + || lower.contains("token") + || lower.contains("secret"); + if is_sensitive { + (key.clone(), "***REDACTED***".into()) + } else { + (key.clone(), value.clone()) + } + }) + .collect() + } + + async fn execute_request( + &self, + url: &str, + method: reqwest::Method, + headers: Vec<(String, String)>, + body: Option<&str>, + ) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .build()?; + + let mut request = client.request(method, url); + + for (key, value) in headers { + request = request.header(&key, &value); + } + + if let Some(body_str) = body { + request = request.body(body_str.to_string()); + } + + Ok(request.send().await?) + } + + fn truncate_response(&self, text: &str) -> String { + if text.len() > self.max_response_size { + let mut truncated = text + .chars() + .take(self.max_response_size) + .collect::(); + truncated.push_str("\n\n... [Response truncated due to size limit] ..."); + truncated + } else { + text.to_string() + } + } +} + +#[async_trait] +impl Tool for HttpRequestTool { + fn name(&self) -> &str { + "http_request" + } + + fn description(&self) -> &str { + "Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS methods. \ + Security constraints: allowlist-only domains, no local/private hosts, configurable timeout and response size limits." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "HTTP or HTTPS URL to request" + }, + "method": { + "type": "string", + "description": "HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)", + "default": "GET" + }, + "headers": { + "type": "object", + "description": "Optional HTTP headers as key-value pairs (e.g., {\"Authorization\": \"Bearer token\", \"Content-Type\": \"application/json\"})", + "default": {} + }, + "body": { + "type": "string", + "description": "Optional request body (for POST, PUT, PATCH requests)" + } + }, + "required": ["url"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let url = args + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'url' parameter"))?; + + let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET"); + let headers_val = args.get("headers").cloned().unwrap_or(json!({})); + let body = args.get("body").and_then(|v| v.as_str()); + + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + let url = match self.validate_url(url) { + Ok(v) => v, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }) + } + }; + + let method = match self.validate_method(method_str) { + Ok(m) => m, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }) + } + }; + + let request_headers = self.parse_headers(&headers_val); + + match self + .execute_request(&url, method, request_headers, body) + .await + { + Ok(response) => { + let status = response.status(); + let status_code = status.as_u16(); + + // Get response headers (redact sensitive ones) + let response_headers = response.headers().iter(); + let headers_text = response_headers + .map(|(k, _)| { + let is_sensitive = k.as_str().to_lowercase().contains("set-cookie"); + if is_sensitive { + format!("{}: ***REDACTED***", k.as_str()) + } else { + format!("{}: {:?}", k.as_str(), k.as_str()) + } + }) + .collect::>() + .join(", "); + + // Get response body with size limit + let response_text = match response.text().await { + Ok(text) => self.truncate_response(&text), + Err(e) => format!("[Failed to read response body: {e}]"), + }; + + let output = format!( + "Status: {} {}\nResponse Headers: {}\n\nResponse Body:\n{}", + status_code, + status.canonical_reason().unwrap_or("Unknown"), + headers_text, + response_text + ); + + Ok(ToolResult { + success: status.is_success(), + output, + error: if status.is_client_error() || status.is_server_error() { + Some(format!("HTTP {}", status_code)) + } else { + None + }, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("HTTP request failed: {e}")), + }), + } + } +} + +// Helper functions similar to browser_open.rs + +fn normalize_allowed_domains(domains: Vec) -> Vec { + let mut normalized = domains + .into_iter() + .filter_map(|d| normalize_domain(&d)) + .collect::>(); + normalized.sort_unstable(); + normalized.dedup(); + normalized +} + +fn normalize_domain(raw: &str) -> Option { + let mut d = raw.trim().to_lowercase(); + if d.is_empty() { + return None; + } + + if let Some(stripped) = d.strip_prefix("https://") { + d = stripped.to_string(); + } else if let Some(stripped) = d.strip_prefix("http://") { + d = stripped.to_string(); + } + + if let Some((host, _)) = d.split_once('/') { + d = host.to_string(); + } + + d = d.trim_start_matches('.').trim_end_matches('.').to_string(); + + if let Some((host, _)) = d.split_once(':') { + d = host.to_string(); + } + + if d.is_empty() || d.chars().any(char::is_whitespace) { + return None; + } + + Some(d) +} + +fn extract_host(url: &str) -> anyhow::Result { + let rest = url + .strip_prefix("http://") + .or_else(|| url.strip_prefix("https://")) + .ok_or_else(|| anyhow::anyhow!("Only http:// and https:// URLs are allowed"))?; + + let authority = rest + .split(['/', '?', '#']) + .next() + .ok_or_else(|| anyhow::anyhow!("Invalid URL"))?; + + if authority.is_empty() { + anyhow::bail!("URL must include a host"); + } + + if authority.contains('@') { + anyhow::bail!("URL userinfo is not allowed"); + } + + if authority.starts_with('[') { + anyhow::bail!("IPv6 hosts are not supported in http_request"); + } + + let host = authority + .split(':') + .next() + .unwrap_or_default() + .trim() + .trim_end_matches('.') + .to_lowercase(); + + if host.is_empty() { + anyhow::bail!("URL must include a valid host"); + } + + Ok(host) +} + +fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { + allowed_domains.iter().any(|domain| { + host == domain + || host + .strip_suffix(domain) + .is_some_and(|prefix| prefix.ends_with('.')) + }) +} + +fn is_private_or_local_host(host: &str) -> bool { + // Strip brackets from IPv6 addresses like [::1] + let bare = host + .strip_prefix('[') + .and_then(|h| h.strip_suffix(']')) + .unwrap_or(host); + + let has_local_tld = bare + .rsplit('.') + .next() + .is_some_and(|label| label == "local"); + + if bare == "localhost" || bare.ends_with(".localhost") || has_local_tld { + return true; + } + + if let Ok(ip) = bare.parse::() { + return match ip { + std::net::IpAddr::V4(v4) => is_non_global_v4(v4), + std::net::IpAddr::V6(v6) => is_non_global_v6(v6), + }; + } + + false +} + +/// Returns true if the IPv4 address is not globally routable. +fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool { + let [a, b, c, _] = v4.octets(); + v4.is_loopback() // 127.0.0.0/8 + || v4.is_private() // 10/8, 172.16/12, 192.168/16 + || v4.is_link_local() // 169.254.0.0/16 + || v4.is_unspecified() // 0.0.0.0 + || v4.is_broadcast() // 255.255.255.255 + || v4.is_multicast() // 224.0.0.0/4 + || (a == 100 && (64..=127).contains(&b)) // Shared address space (RFC 6598) + || a >= 240 // Reserved (240.0.0.0/4, except broadcast) + || (a == 192 && b == 0 && (c == 0 || c == 2)) // IETF assignments + TEST-NET-1 + || (a == 198 && b == 51) // Documentation (198.51.100.0/24) + || (a == 203 && b == 0) // Documentation (203.0.113.0/24) + || (a == 198 && (18..=19).contains(&b)) // Benchmarking (198.18.0.0/15) +} + +/// Returns true if the IPv6 address is not globally routable. +fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool { + let segs = v6.segments(); + v6.is_loopback() // ::1 + || v6.is_unspecified() // :: + || v6.is_multicast() // ff00::/8 + || (segs[0] & 0xfe00) == 0xfc00 // Unique-local (fc00::/7) + || (segs[0] & 0xffc0) == 0xfe80 // Link-local (fe80::/10) + || (segs[0] == 0x2001 && segs[1] == 0x0db8) // Documentation (2001:db8::/32) + || v6.to_ipv4_mapped().is_some_and(is_non_global_v4) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + ..SecurityPolicy::default() + }); + HttpRequestTool::new( + security, + allowed_domains.into_iter().map(String::from).collect(), + 1_000_000, + 30, + ) + } + + #[test] + fn normalize_domain_strips_scheme_path_and_case() { + let got = normalize_domain(" HTTPS://Docs.Example.com/path ").unwrap(); + assert_eq!(got, "docs.example.com"); + } + + #[test] + fn normalize_allowed_domains_deduplicates() { + let got = normalize_allowed_domains(vec![ + "example.com".into(), + "EXAMPLE.COM".into(), + "https://example.com/".into(), + ]); + assert_eq!(got, vec!["example.com".to_string()]); + } + + #[test] + fn validate_accepts_exact_domain() { + let tool = test_tool(vec!["example.com"]); + let got = tool.validate_url("https://example.com/docs").unwrap(); + assert_eq!(got, "https://example.com/docs"); + } + + #[test] + fn validate_accepts_http() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_url("http://example.com").is_ok()); + } + + #[test] + fn validate_accepts_subdomain() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_url("https://api.example.com/v1").is_ok()); + } + + #[test] + fn validate_rejects_allowlist_miss() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://google.com") + .unwrap_err() + .to_string(); + assert!(err.contains("allowed_domains")); + } + + #[test] + fn validate_rejects_localhost() { + let tool = test_tool(vec!["localhost"]); + let err = tool + .validate_url("https://localhost:8080") + .unwrap_err() + .to_string(); + assert!(err.contains("local/private")); + } + + #[test] + fn validate_rejects_private_ipv4() { + let tool = test_tool(vec!["192.168.1.5"]); + let err = tool + .validate_url("https://192.168.1.5") + .unwrap_err() + .to_string(); + assert!(err.contains("local/private")); + } + + #[test] + fn validate_rejects_whitespace() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://example.com/hello world") + .unwrap_err() + .to_string(); + assert!(err.contains("whitespace")); + } + + #[test] + fn validate_rejects_userinfo() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://user@example.com") + .unwrap_err() + .to_string(); + assert!(err.contains("userinfo")); + } + + #[test] + fn validate_requires_allowlist() { + let security = Arc::new(SecurityPolicy::default()); + let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30); + let err = tool + .validate_url("https://example.com") + .unwrap_err() + .to_string(); + assert!(err.contains("allowed_domains")); + } + + #[test] + fn validate_accepts_valid_methods() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_method("GET").is_ok()); + assert!(tool.validate_method("POST").is_ok()); + assert!(tool.validate_method("PUT").is_ok()); + assert!(tool.validate_method("DELETE").is_ok()); + assert!(tool.validate_method("PATCH").is_ok()); + assert!(tool.validate_method("HEAD").is_ok()); + assert!(tool.validate_method("OPTIONS").is_ok()); + } + + #[test] + fn validate_rejects_invalid_method() { + let tool = test_tool(vec!["example.com"]); + let err = tool.validate_method("INVALID").unwrap_err().to_string(); + assert!(err.contains("Unsupported HTTP method")); + } + + #[test] + fn blocks_multicast_ipv4() { + assert!(is_private_or_local_host("224.0.0.1")); + assert!(is_private_or_local_host("239.255.255.255")); + } + + #[test] + fn blocks_broadcast() { + assert!(is_private_or_local_host("255.255.255.255")); + } + + #[test] + fn blocks_reserved_ipv4() { + assert!(is_private_or_local_host("240.0.0.1")); + assert!(is_private_or_local_host("250.1.2.3")); + } + + #[test] + fn blocks_documentation_ranges() { + assert!(is_private_or_local_host("192.0.2.1")); // TEST-NET-1 + assert!(is_private_or_local_host("198.51.100.1")); // TEST-NET-2 + assert!(is_private_or_local_host("203.0.113.1")); // TEST-NET-3 + } + + #[test] + fn blocks_benchmarking_range() { + assert!(is_private_or_local_host("198.18.0.1")); + assert!(is_private_or_local_host("198.19.255.255")); + } + + #[test] + fn blocks_ipv6_localhost() { + assert!(is_private_or_local_host("::1")); + assert!(is_private_or_local_host("[::1]")); + } + + #[test] + fn blocks_ipv6_multicast() { + assert!(is_private_or_local_host("ff02::1")); + } + + #[test] + fn blocks_ipv6_link_local() { + assert!(is_private_or_local_host("fe80::1")); + } + + #[test] + fn blocks_ipv6_unique_local() { + assert!(is_private_or_local_host("fd00::1")); + } + + #[test] + fn blocks_ipv4_mapped_ipv6() { + assert!(is_private_or_local_host("::ffff:127.0.0.1")); + assert!(is_private_or_local_host("::ffff:192.168.1.1")); + assert!(is_private_or_local_host("::ffff:10.0.0.1")); + } + + #[test] + fn allows_public_ipv4() { + assert!(!is_private_or_local_host("8.8.8.8")); + assert!(!is_private_or_local_host("1.1.1.1")); + assert!(!is_private_or_local_host("93.184.216.34")); + } + + #[test] + fn blocks_ipv6_documentation_range() { + assert!(is_private_or_local_host("2001:db8::1")); + } + + #[test] + fn allows_public_ipv6() { + assert!(!is_private_or_local_host("2607:f8b0:4004:800::200e")); + } + + #[test] + fn blocks_shared_address_space() { + assert!(is_private_or_local_host("100.64.0.1")); + assert!(is_private_or_local_host("100.127.255.255")); + assert!(!is_private_or_local_host("100.63.0.1")); // Just below range + assert!(!is_private_or_local_host("100.128.0.1")); // Just above range + } + + #[tokio::test] + async fn execute_blocks_readonly_mode() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30); + let result = tool + .execute(json!({"url": "https://example.com"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn execute_blocks_when_rate_limited() { + let security = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30); + let result = tool + .execute(json!({"url": "https://example.com"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("rate limit")); + } + + #[test] + fn truncate_response_within_limit() { + let tool = test_tool(vec!["example.com"]); + let text = "hello world"; + assert_eq!(tool.truncate_response(text), "hello world"); + } + + #[test] + fn truncate_response_over_limit() { + let tool = HttpRequestTool::new( + Arc::new(SecurityPolicy::default()), + vec!["example.com".into()], + 10, + 30, + ); + let text = "hello world this is long"; + let truncated = tool.truncate_response(text); + assert!(truncated.len() <= 10 + 60); // limit + message + assert!(truncated.contains("[Response truncated")); + } + + #[test] + fn parse_headers_preserves_original_values() { + let tool = test_tool(vec!["example.com"]); + let headers = json!({ + "Authorization": "Bearer secret", + "Content-Type": "application/json", + "X-API-Key": "my-key" + }); + let parsed = tool.parse_headers(&headers); + assert_eq!(parsed.len(), 3); + assert!(parsed + .iter() + .any(|(k, v)| k == "Authorization" && v == "Bearer secret")); + assert!(parsed + .iter() + .any(|(k, v)| k == "X-API-Key" && v == "my-key")); + assert!(parsed + .iter() + .any(|(k, v)| k == "Content-Type" && v == "application/json")); + } + + #[test] + fn redact_headers_for_display_redacts_sensitive() { + let headers = vec![ + ("Authorization".into(), "Bearer secret".into()), + ("Content-Type".into(), "application/json".into()), + ("X-API-Key".into(), "my-key".into()), + ("X-Secret-Token".into(), "tok-123".into()), + ]; + let redacted = HttpRequestTool::redact_headers_for_display(&headers); + assert_eq!(redacted.len(), 4); + assert!(redacted + .iter() + .any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "X-Secret-Token" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "Content-Type" && v == "application/json")); + } + + #[test] + fn redact_headers_does_not_alter_original() { + let headers = vec![("Authorization".into(), "Bearer real-token".into())]; + let _ = HttpRequestTool::redact_headers_for_display(&headers); + assert_eq!(headers[0].1, "Bearer real-token"); + } + + // ── SSRF: alternate IP notation bypass defense-in-depth ───────── + // + // Rust's IpAddr::parse() rejects non-standard notations (octal, hex, + // decimal integer, zero-padded). These tests document that property + // so regressions are caught if the parsing strategy ever changes. + + #[test] + fn ssrf_octal_loopback_not_parsed_as_ip() { + // 0177.0.0.1 is octal for 127.0.0.1 in some languages, but + // Rust's IpAddr rejects it — it falls through as a hostname. + assert!(!is_private_or_local_host("0177.0.0.1")); + } + + #[test] + fn ssrf_hex_loopback_not_parsed_as_ip() { + // 0x7f000001 is hex for 127.0.0.1 in some languages. + assert!(!is_private_or_local_host("0x7f000001")); + } + + #[test] + fn ssrf_decimal_loopback_not_parsed_as_ip() { + // 2130706433 is decimal for 127.0.0.1 in some languages. + assert!(!is_private_or_local_host("2130706433")); + } + + #[test] + fn ssrf_zero_padded_loopback_not_parsed_as_ip() { + // 127.000.000.001 uses zero-padded octets. + assert!(!is_private_or_local_host("127.000.000.001")); + } + + #[test] + fn ssrf_alternate_notations_rejected_by_validate_url() { + // Even if is_private_or_local_host doesn't flag these, they + // fail the allowlist because they're treated as hostnames. + let tool = test_tool(vec!["example.com"]); + for notation in [ + "http://0177.0.0.1", + "http://0x7f000001", + "http://2130706433", + "http://127.000.000.001", + ] { + let err = tool.validate_url(notation).unwrap_err().to_string(); + assert!( + err.contains("allowed_domains"), + "Expected allowlist rejection for {notation}, got: {err}" + ); + } + } +} diff --git a/src/tools/image_info.rs b/src/tools/image_info.rs new file mode 100644 index 0000000..349f707 --- /dev/null +++ b/src/tools/image_info.rs @@ -0,0 +1,493 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::fmt::Write; +use std::path::Path; +use std::sync::Arc; + +/// Maximum file size we will read and base64-encode (5 MB). +const MAX_IMAGE_BYTES: u64 = 5_242_880; + +/// Tool to read image metadata and optionally return base64-encoded data. +/// +/// Since providers are currently text-only, this tool extracts what it can +/// (file size, format, dimensions from header bytes) and provides base64 +/// data for future multimodal provider support. +pub struct ImageInfoTool { + security: Arc, +} + +impl ImageInfoTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Detect image format from first few bytes (magic numbers). + fn detect_format(bytes: &[u8]) -> &'static str { + if bytes.len() < 4 { + return "unknown"; + } + if bytes.starts_with(b"\x89PNG") { + "png" + } else if bytes.starts_with(b"\xFF\xD8\xFF") { + "jpeg" + } else if bytes.starts_with(b"GIF8") { + "gif" + } else if bytes.starts_with(b"RIFF") && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" { + "webp" + } else if bytes.starts_with(b"BM") { + "bmp" + } else { + "unknown" + } + } + + /// Try to extract dimensions from image header bytes. + /// Returns (width, height) if detectable. + fn extract_dimensions(bytes: &[u8], format: &str) -> Option<(u32, u32)> { + match format { + "png" => { + // PNG IHDR chunk: bytes 16-19 = width, 20-23 = height (big-endian) + if bytes.len() >= 24 { + let w = u32::from_be_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]); + let h = u32::from_be_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]); + Some((w, h)) + } else { + None + } + } + "gif" => { + // GIF: bytes 6-7 = width, 8-9 = height (little-endian) + if bytes.len() >= 10 { + let w = u32::from(u16::from_le_bytes([bytes[6], bytes[7]])); + let h = u32::from(u16::from_le_bytes([bytes[8], bytes[9]])); + Some((w, h)) + } else { + None + } + } + "bmp" => { + // BMP: bytes 18-21 = width, 22-25 = height (little-endian, signed) + if bytes.len() >= 26 { + let w = u32::from_le_bytes([bytes[18], bytes[19], bytes[20], bytes[21]]); + let h_raw = i32::from_le_bytes([bytes[22], bytes[23], bytes[24], bytes[25]]); + let h = h_raw.unsigned_abs(); + Some((w, h)) + } else { + None + } + } + "jpeg" => Self::jpeg_dimensions(bytes), + _ => None, + } + } + + /// Parse JPEG SOF markers to extract dimensions. + fn jpeg_dimensions(bytes: &[u8]) -> Option<(u32, u32)> { + let mut i = 2; // skip SOI marker + while i + 1 < bytes.len() { + if bytes[i] != 0xFF { + return None; + } + let marker = bytes[i + 1]; + i += 2; + + // SOF0..SOF3 markers contain dimensions + if (0xC0..=0xC3).contains(&marker) { + if i + 7 <= bytes.len() { + let h = u32::from(u16::from_be_bytes([bytes[i + 3], bytes[i + 4]])); + let w = u32::from(u16::from_be_bytes([bytes[i + 5], bytes[i + 6]])); + return Some((w, h)); + } + return None; + } + + // Skip this segment + if i + 1 < bytes.len() { + let seg_len = u16::from_be_bytes([bytes[i], bytes[i + 1]]) as usize; + if seg_len < 2 { + return None; // Malformed segment (valid segments have length >= 2) + } + i += seg_len; + } else { + return None; + } + } + None + } +} + +#[async_trait] +impl Tool for ImageInfoTool { + fn name(&self) -> &str { + "image_info" + } + + fn description(&self) -> &str { + "Read image file metadata (format, dimensions, size) and optionally return base64-encoded data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the image file (absolute or relative to workspace)" + }, + "include_base64": { + "type": "boolean", + "description": "Include base64-encoded image data in output (default: false)" + } + }, + "required": ["path"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path_str = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + + let include_base64 = args + .get("include_base64") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false); + + let path = Path::new(path_str); + + // Restrict reads to workspace directory to prevent arbitrary file exfiltration + if !self.security.is_path_allowed(path_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Path not allowed: {path_str} (must be within workspace)" + )), + }); + } + + if !path.exists() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("File not found: {path_str}")), + }); + } + + let metadata = tokio::fs::metadata(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read file metadata: {e}"))?; + + let file_size = metadata.len(); + + if file_size > MAX_IMAGE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Image too large: {file_size} bytes (max {MAX_IMAGE_BYTES} bytes)" + )), + }); + } + + let bytes = tokio::fs::read(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read image file: {e}"))?; + + let format = Self::detect_format(&bytes); + let dimensions = Self::extract_dimensions(&bytes, format); + + let mut output = format!("File: {path_str}\nFormat: {format}\nSize: {file_size} bytes"); + + if let Some((w, h)) = dimensions { + let _ = write!(output, "\nDimensions: {w}x{h}"); + } + + if include_base64 { + use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let mime = match format { + "png" => "image/png", + "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "bmp" => "image/bmp", + _ => "application/octet-stream", + }; + let _ = write!(output, "\ndata:{mime};base64,{encoded}"); + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + workspace_only: false, + forbidden_paths: vec![], + ..SecurityPolicy::default() + }) + } + + #[test] + fn image_info_tool_name() { + let tool = ImageInfoTool::new(test_security()); + assert_eq!(tool.name(), "image_info"); + } + + #[test] + fn image_info_tool_description() { + let tool = ImageInfoTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("image")); + } + + #[test] + fn image_info_tool_schema() { + let tool = ImageInfoTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["path"].is_object()); + assert!(schema["properties"]["include_base64"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("path"))); + } + + #[test] + fn image_info_tool_spec() { + let tool = ImageInfoTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "image_info"); + assert!(spec.parameters.is_object()); + } + + // ── Format detection ──────────────────────────────────────── + + #[test] + fn detect_png() { + let bytes = b"\x89PNG\r\n\x1a\n"; + assert_eq!(ImageInfoTool::detect_format(bytes), "png"); + } + + #[test] + fn detect_jpeg() { + let bytes = b"\xFF\xD8\xFF\xE0"; + assert_eq!(ImageInfoTool::detect_format(bytes), "jpeg"); + } + + #[test] + fn detect_gif() { + let bytes = b"GIF89a"; + assert_eq!(ImageInfoTool::detect_format(bytes), "gif"); + } + + #[test] + fn detect_webp() { + let bytes = b"RIFF\x00\x00\x00\x00WEBP"; + assert_eq!(ImageInfoTool::detect_format(bytes), "webp"); + } + + #[test] + fn detect_bmp() { + let bytes = b"BM\x00\x00"; + assert_eq!(ImageInfoTool::detect_format(bytes), "bmp"); + } + + #[test] + fn detect_unknown_short() { + let bytes = b"\x00\x01"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + #[test] + fn detect_unknown_garbage() { + let bytes = b"this is not an image"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + // ── Dimension extraction ──────────────────────────────────── + + #[test] + fn png_dimensions() { + // Minimal PNG IHDR: 8-byte signature + 4-byte length + 4-byte IHDR + 4-byte width + 4-byte height + let mut bytes = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x03, 0x20, // width: 800 + 0x00, 0x00, 0x02, 0x58, // height: 600 + ]; + bytes.extend_from_slice(&[0u8; 10]); // padding + let dims = ImageInfoTool::extract_dimensions(&bytes, "png"); + assert_eq!(dims, Some((800, 600))); + } + + #[test] + fn gif_dimensions() { + let bytes = [ + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // GIF89a + 0x40, 0x01, // width: 320 (LE) + 0xF0, 0x00, // height: 240 (LE) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "gif"); + assert_eq!(dims, Some((320, 240))); + } + + #[test] + fn bmp_dimensions() { + let mut bytes = vec![0u8; 26]; + bytes[0] = b'B'; + bytes[1] = b'M'; + // width at offset 18 (LE): 1024 + bytes[18] = 0x00; + bytes[19] = 0x04; + bytes[20] = 0x00; + bytes[21] = 0x00; + // height at offset 22 (LE): 768 + bytes[22] = 0x00; + bytes[23] = 0x03; + bytes[24] = 0x00; + bytes[25] = 0x00; + let dims = ImageInfoTool::extract_dimensions(&bytes, "bmp"); + assert_eq!(dims, Some((1024, 768))); + } + + #[test] + fn jpeg_dimensions() { + // Minimal JPEG-like byte sequence with SOF0 marker + let mut bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x10, // APP0 length = 16 + ]; + bytes.extend_from_slice(&[0u8; 14]); // APP0 payload + bytes.extend_from_slice(&[ + 0xFF, 0xC0, // SOF0 marker + 0x00, 0x11, // SOF0 length + 0x08, // precision + 0x01, 0xE0, // height: 480 + 0x02, 0x80, // width: 640 + ]); + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert_eq!(dims, Some((640, 480))); + } + + #[test] + fn jpeg_malformed_zero_length_segment() { + // Zero-length segment should return None instead of looping forever + let bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x00, // length = 0 (malformed) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert!(dims.is_none()); + } + + #[test] + fn unknown_format_no_dimensions() { + let bytes = b"random data here"; + let dims = ImageInfoTool::extract_dimensions(bytes, "unknown"); + assert!(dims.is_none()); + } + + // ── Execute tests ─────────────────────────────────────────── + + #[tokio::test] + async fn execute_missing_path() { + let tool = ImageInfoTool::new(test_security()); + let result = tool.execute(json!({})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn execute_nonexistent_file() { + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": "/tmp/nonexistent_image_xyz.png"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("not found")); + } + + #[tokio::test] + async fn execute_real_file() { + // Create a minimal valid PNG + let dir = std::env::temp_dir().join("zeroclaw_image_info_test"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test.png"); + + // Minimal 1x1 red PNG (67 bytes) + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // IHDR + 0x00, 0x00, 0x00, 0x01, // width: 1 + 0x00, 0x00, 0x00, 0x01, // height: 1 + 0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc. + 0x90, 0x77, 0x53, 0xDE, // CRC + 0x00, 0x00, 0x00, 0x0C, // IDAT length + 0x49, 0x44, 0x41, 0x54, // IDAT + 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, + 0xBC, 0x33, // CRC + 0x00, 0x00, 0x00, 0x00, // IEND length + 0x49, 0x45, 0x4E, 0x44, // IEND + 0xAE, 0x42, 0x60, 0x82, // CRC + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy()})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Format: png")); + assert!(result.output.contains("Dimensions: 1x1")); + assert!(!result.output.contains("data:")); + + // Clean up + let _ = std::fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn execute_with_base64() { + let dir = std::env::temp_dir().join("zeroclaw_image_info_b64"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test_b64.png"); + + // Minimal 1x1 PNG + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00, + 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, + 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, + 0x33, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy(), "include_base64": true})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("data:image/png;base64,")); + + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs index 16b2b8a..a53885e 100644 --- a/src/tools/memory_forget.rs +++ b/src/tools/memory_forget.rs @@ -87,7 +87,7 @@ mod tests { #[tokio::test] async fn forget_existing() { let (_tmp, mem) = test_mem(); - mem.store("temp", "temporary", MemoryCategory::Conversation) + mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index ff1385a..fada306 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool { .and_then(serde_json::Value::as_u64) .map_or(5, |v| v as usize); - match self.memory.recall(query, limit).await { + match self.memory.recall(query, limit, None).await { Ok(entries) if entries.is_empty() => Ok(ToolResult { success: true, output: "No memories found matching that query.".into(), @@ -112,10 +112,10 @@ mod tests { #[tokio::test] async fn recall_finds_match() { let (_tmp, mem) = seeded_mem(); - mem.store("lang", "User prefers Rust", MemoryCategory::Core) + mem.store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("tz", "Timezone is EST", MemoryCategory::Core) + mem.store("tz", "Timezone is EST", MemoryCategory::Core, None) .await .unwrap(); @@ -134,6 +134,7 @@ mod tests { &format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs index b90222c..d2aad40 100644 --- a/src/tools/memory_store.rs +++ b/src/tools/memory_store.rs @@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool { _ => MemoryCategory::Core, }; - match self.memory.store(key, content, category).await { + match self.memory.store(key, content, category, None).await { Ok(()) => Ok(ToolResult { success: true, output: format!("Stored memory: {key}"), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index e02154d..3c6309f 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,54 +1,155 @@ pub mod browser; pub mod browser_open; pub mod composio; +pub mod cron_add; +pub mod cron_list; +pub mod cron_remove; +pub mod cron_run; +pub mod cron_runs; +pub mod cron_update; +pub mod delegate; pub mod file_read; pub mod file_write; +pub mod git_operations; +pub mod hardware_board_info; +pub mod hardware_memory_map; +pub mod hardware_memory_read; +pub mod http_request; +pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod pushover; +pub mod schedule; +pub mod schema; +pub mod screenshot; pub mod shell; pub mod traits; -pub use browser::BrowserTool; +pub use browser::{BrowserTool, ComputerUseConfig}; pub use browser_open::BrowserOpenTool; pub use composio::ComposioTool; +pub use cron_add::CronAddTool; +pub use cron_list::CronListTool; +pub use cron_remove::CronRemoveTool; +pub use cron_run::CronRunTool; +pub use cron_runs::CronRunsTool; +pub use cron_update::CronUpdateTool; +pub use delegate::DelegateTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use git_operations::GitOperationsTool; +pub use hardware_board_info::HardwareBoardInfoTool; +pub use hardware_memory_map::HardwareMemoryMapTool; +pub use hardware_memory_read::HardwareMemoryReadTool; +pub use http_request::HttpRequestTool; +pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use pushover::PushoverTool; +pub use schedule::ScheduleTool; +#[allow(unused_imports)] +pub use schema::{CleaningStrategy, SchemaCleanr}; +pub use screenshot::ScreenshotTool; pub use shell::ShellTool; pub use traits::Tool; #[allow(unused_imports)] pub use traits::{ToolResult, ToolSpec}; +use crate::config::{Config, DelegateAgentConfig}; use crate::memory::Memory; +use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::SecurityPolicy; +use std::collections::HashMap; use std::sync::Arc; /// Create the default tool registry pub fn default_tools(security: Arc) -> Vec> { + default_tools_with_runtime(security, Arc::new(NativeRuntime::new())) +} + +/// Create the default tool registry with explicit runtime adapter. +pub fn default_tools_with_runtime( + security: Arc, + runtime: Arc, +) -> Vec> { vec![ - Box::new(ShellTool::new(security.clone())), + Box::new(ShellTool::new(security.clone(), runtime)), Box::new(FileReadTool::new(security.clone())), Box::new(FileWriteTool::new(security)), ] } /// Create full tool registry including memory tools and optional Composio +#[allow(clippy::implicit_hasher, clippy::too_many_arguments)] pub fn all_tools( + config: Arc, security: &Arc, memory: Arc, composio_key: Option<&str>, + composio_entity_id: Option<&str>, browser_config: &crate::config::BrowserConfig, + http_config: &crate::config::HttpRequestConfig, + workspace_dir: &std::path::Path, + agents: &HashMap, + fallback_api_key: Option<&str>, + root_config: &crate::config::Config, +) -> Vec> { + all_tools_with_runtime( + config, + security, + Arc::new(NativeRuntime::new()), + memory, + composio_key, + composio_entity_id, + browser_config, + http_config, + workspace_dir, + agents, + fallback_api_key, + root_config, + ) +} + +/// Create full tool registry including memory tools and optional Composio. +#[allow(clippy::implicit_hasher, clippy::too_many_arguments)] +pub fn all_tools_with_runtime( + config: Arc, + security: &Arc, + runtime: Arc, + memory: Arc, + composio_key: Option<&str>, + composio_entity_id: Option<&str>, + browser_config: &crate::config::BrowserConfig, + http_config: &crate::config::HttpRequestConfig, + workspace_dir: &std::path::Path, + agents: &HashMap, + fallback_api_key: Option<&str>, + root_config: &crate::config::Config, ) -> Vec> { let mut tools: Vec> = vec![ - Box::new(ShellTool::new(security.clone())), + Box::new(ShellTool::new(security.clone(), runtime)), Box::new(FileReadTool::new(security.clone())), Box::new(FileWriteTool::new(security.clone())), + Box::new(CronAddTool::new(config.clone(), security.clone())), + Box::new(CronListTool::new(config.clone())), + Box::new(CronRemoveTool::new(config.clone())), + Box::new(CronUpdateTool::new(config.clone(), security.clone())), + Box::new(CronRunTool::new(config.clone())), + Box::new(CronRunsTool::new(config.clone())), Box::new(MemoryStoreTool::new(memory.clone())), Box::new(MemoryRecallTool::new(memory.clone())), Box::new(MemoryForgetTool::new(memory)), + Box::new(ScheduleTool::new(security.clone(), root_config.clone())), + Box::new(GitOperationsTool::new( + security.clone(), + workspace_dir.to_path_buf(), + )), + Box::new(PushoverTool::new( + security.clone(), + workspace_dir.to_path_buf(), + )), ]; if browser_config.enabled { @@ -57,29 +158,79 @@ pub fn all_tools( security.clone(), browser_config.allowed_domains.clone(), ))); - // Add full browser automation tool (agent-browser) - tools.push(Box::new(BrowserTool::new( + // Add full browser automation tool (pluggable backend) + tools.push(Box::new(BrowserTool::new_with_backend( security.clone(), browser_config.allowed_domains.clone(), browser_config.session_name.clone(), + browser_config.backend.clone(), + browser_config.native_headless, + browser_config.native_webdriver_url.clone(), + browser_config.native_chrome_path.clone(), + ComputerUseConfig { + endpoint: browser_config.computer_use.endpoint.clone(), + api_key: browser_config.computer_use.api_key.clone(), + timeout_ms: browser_config.computer_use.timeout_ms, + allow_remote_endpoint: browser_config.computer_use.allow_remote_endpoint, + window_allowlist: browser_config.computer_use.window_allowlist.clone(), + max_coordinate_x: browser_config.computer_use.max_coordinate_x, + max_coordinate_y: browser_config.computer_use.max_coordinate_y, + }, ))); } + if http_config.enabled { + tools.push(Box::new(HttpRequestTool::new( + security.clone(), + http_config.allowed_domains.clone(), + http_config.max_response_size, + http_config.timeout_secs, + ))); + } + + // Vision tools are always available + tools.push(Box::new(ScreenshotTool::new(security.clone()))); + tools.push(Box::new(ImageInfoTool::new(security.clone()))); + if let Some(key) = composio_key { if !key.is_empty() { - tools.push(Box::new(ComposioTool::new(key))); + tools.push(Box::new(ComposioTool::new(key, composio_entity_id))); } } + // Add delegation tool when agents are configured + if !agents.is_empty() { + let delegate_agents: HashMap = agents + .iter() + .map(|(name, cfg)| (name.clone(), cfg.clone())) + .collect(); + let delegate_fallback_credential = fallback_api_key.and_then(|value| { + let trimmed_value = value.trim(); + (!trimmed_value.is_empty()).then(|| trimmed_value.to_owned()) + }); + tools.push(Box::new(DelegateTool::new( + delegate_agents, + delegate_fallback_credential, + ))); + } + tools } #[cfg(test)] mod tests { use super::*; - use crate::config::{BrowserConfig, MemoryConfig}; + use crate::config::{BrowserConfig, Config, MemoryConfig}; use tempfile::TempDir; + fn test_config(tmp: &TempDir) -> Config { + Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + } + } + #[test] fn default_tools_has_three() { let security = Arc::new(SecurityPolicy::default()); @@ -102,11 +253,28 @@ mod tests { enabled: false, allowed_domains: vec!["example.com".into()], session_name: None, + ..BrowserConfig::default() }; + let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); - let tools = all_tools(&security, mem, None, &browser); + let tools = all_tools( + Arc::new(Config::default()), + &security, + mem, + None, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + &cfg, + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); + assert!(names.contains(&"schedule")); + assert!(names.contains(&"pushover")); } #[test] @@ -124,11 +292,27 @@ mod tests { enabled: true, allowed_domains: vec!["example.com".into()], session_name: None, + ..BrowserConfig::default() }; + let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); - let tools = all_tools(&security, mem, None, &browser); + let tools = all_tools( + Arc::new(Config::default()), + &security, + mem, + None, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + &cfg, + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); + assert!(names.contains(&"pushover")); } #[test] @@ -224,4 +408,81 @@ mod tests { assert_eq!(parsed.name, "test"); assert_eq!(parsed.description, "A test tool"); } + + #[test] + fn all_tools_includes_delegate_when_agents_configured() { + let tmp = TempDir::new().unwrap(); + let security = Arc::new(SecurityPolicy::default()); + let mem_cfg = MemoryConfig { + backend: "markdown".into(), + ..MemoryConfig::default() + }; + let mem: Arc = + Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); + + let browser = BrowserConfig::default(); + let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); + + let mut agents = HashMap::new(); + agents.insert( + "researcher".to_string(), + DelegateAgentConfig { + provider: "ollama".to_string(), + model: "llama3".to_string(), + system_prompt: None, + api_key: None, + temperature: None, + max_depth: 3, + }, + ); + + let tools = all_tools( + Arc::new(Config::default()), + &security, + mem, + None, + None, + &browser, + &http, + tmp.path(), + &agents, + Some("delegate-test-credential"), + &cfg, + ); + let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); + assert!(names.contains(&"delegate")); + } + + #[test] + fn all_tools_excludes_delegate_when_no_agents() { + let tmp = TempDir::new().unwrap(); + let security = Arc::new(SecurityPolicy::default()); + let mem_cfg = MemoryConfig { + backend: "markdown".into(), + ..MemoryConfig::default() + }; + let mem: Arc = + Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); + + let browser = BrowserConfig::default(); + let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); + + let tools = all_tools( + Arc::new(Config::default()), + &security, + mem, + None, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + &cfg, + ); + let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); + assert!(!names.contains(&"delegate")); + } } diff --git a/src/tools/pushover.rs b/src/tools/pushover.rs new file mode 100644 index 0000000..ad1d385 --- /dev/null +++ b/src/tools/pushover.rs @@ -0,0 +1,442 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::json; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json"; +const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15; + +pub struct PushoverTool { + client: Client, + security: Arc, + workspace_dir: PathBuf, +} + +impl PushoverTool { + pub fn new(security: Arc, workspace_dir: PathBuf) -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS)) + .build() + .unwrap_or_else(|_| Client::new()); + + Self { + client, + security, + workspace_dir, + } + } + + fn parse_env_value(raw: &str) -> String { + let raw = raw.trim(); + + let unquoted = if raw.len() >= 2 + && ((raw.starts_with('"') && raw.ends_with('"')) + || (raw.starts_with('\'') && raw.ends_with('\''))) + { + &raw[1..raw.len() - 1] + } else { + raw + }; + + // Keep support for inline comments in unquoted values: + // KEY=value # comment + unquoted.split_once(" #").map_or_else( + || unquoted.trim().to_string(), + |(value, _)| value.trim().to_string(), + ) + } + + fn get_credentials(&self) -> anyhow::Result<(String, String)> { + let env_path = self.workspace_dir.join(".env"); + let content = std::fs::read_to_string(&env_path) + .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?; + + let mut token = None; + let mut user_key = None; + + for line in content.lines() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + continue; + } + let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line); + if let Some((key, value)) = line.split_once('=') { + let key = key.trim(); + let value = Self::parse_env_value(value); + + if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") { + token = Some(value); + } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") { + user_key = Some(value); + } + } + } + + let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?; + let user_key = + user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?; + + Ok((token, user_key)) + } +} + +#[async_trait] +impl Tool for PushoverTool { + fn name(&self) -> &str { + "pushover" + } + + fn description(&self) -> &str { + "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The notification message to send" + }, + "title": { + "type": "string", + "description": "Optional notification title" + }, + "priority": { + "type": "integer", + "enum": [-2, -1, 0, 1, 2], + "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)" + }, + "sound": { + "type": "string", + "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)" + } + }, + "required": ["message"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + let message = args + .get("message") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|v| !v.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))? + .to_string(); + + let title = args.get("title").and_then(|v| v.as_str()).map(String::from); + + let priority = match args.get("priority").and_then(|v| v.as_i64()) { + Some(value) if (-2..=2).contains(&value) => Some(value), + Some(value) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Invalid 'priority': {value}. Expected integer in range -2..=2" + )), + }) + } + None => None, + }; + + let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from); + + let (token, user_key) = self.get_credentials()?; + + let mut form = reqwest::multipart::Form::new() + .text("token", token) + .text("user", user_key) + .text("message", message); + + if let Some(title) = title { + form = form.text("title", title); + } + + if let Some(priority) = priority { + form = form.text("priority", priority.to_string()); + } + + if let Some(sound) = sound { + form = form.text("sound", sound); + } + + let response = self + .client + .post(PUSHOVER_API_URL) + .multipart(form) + .send() + .await?; + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if !status.is_success() { + return Ok(ToolResult { + success: false, + output: body, + error: Some(format!("Pushover API returned status {}", status)), + }); + } + + let api_status = serde_json::from_str::(&body) + .ok() + .and_then(|json| json.get("status").and_then(|value| value.as_i64())); + + if api_status == Some(1) { + Ok(ToolResult { + success: true, + output: format!( + "Pushover notification sent successfully. Response: {}", + body + ), + error: None, + }) + } else { + Ok(ToolResult { + success: false, + output: body, + error: Some("Pushover API returned an application-level error".into()), + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::AutonomyLevel; + use std::fs; + use tempfile::TempDir; + + fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc { + Arc::new(SecurityPolicy { + autonomy: level, + max_actions_per_hour, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + #[test] + fn pushover_tool_name() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + assert_eq!(tool.name(), "pushover"); + } + + #[test] + fn pushover_tool_description() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + assert!(!tool.description().is_empty()); + } + + #[test] + fn pushover_tool_has_parameters_schema() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"].get("message").is_some()); + } + + #[test] + fn pushover_tool_requires_message() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::Value::String("message".to_string()))); + } + + #[test] + fn credentials_parsed_from_env_file() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n", + ) + .unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "testtoken123"); + assert_eq!(user_key, "userkey456"); + } + + #[test] + fn credentials_fail_without_env_file() { + let tmp = TempDir::new().unwrap(); + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_fail_without_token() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_fail_without_user_key() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_ignore_comments() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "realtoken"); + assert_eq!(user_key, "realuser"); + } + + #[test] + fn pushover_tool_supports_priority() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("priority").is_some()); + } + + #[test] + fn pushover_tool_supports_sound() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("sound").is_some()); + } + + #[test] + fn credentials_support_export_and_quoted_values() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n", + ) + .unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "quotedtoken"); + assert_eq!(user_key, "quoteduser"); + } + + #[tokio::test] + async fn execute_blocks_readonly_mode() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::ReadOnly, 100), + PathBuf::from("/tmp"), + ); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn execute_blocks_rate_limit() { + let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp")); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("rate limit")); + } + + #[tokio::test] + async fn execute_rejects_priority_out_of_range() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + + let result = tool + .execute(json!({"message": "hello", "priority": 5})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("-2..=2")); + } +} diff --git a/src/tools/schedule.rs b/src/tools/schedule.rs new file mode 100644 index 0000000..96c3023 --- /dev/null +++ b/src/tools/schedule.rs @@ -0,0 +1,524 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use crate::security::SecurityPolicy; +use anyhow::Result; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde_json::json; +use std::sync::Arc; + +/// Tool that lets the agent manage recurring and one-shot scheduled tasks. +pub struct ScheduleTool { + security: Arc, + config: Config, +} + +impl ScheduleTool { + pub fn new(security: Arc, config: Config) -> Self { + Self { security, config } + } +} + +#[async_trait] +impl Tool for ScheduleTool { + fn name(&self) -> &str { + "schedule" + } + + fn description(&self) -> &str { + "Manage scheduled tasks. Actions: create/add/once/list/get/cancel/remove/pause/resume" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["create", "add", "once", "list", "get", "cancel", "remove", "pause", "resume"], + "description": "Action to perform" + }, + "expression": { + "type": "string", + "description": "Cron expression for recurring tasks (e.g. '*/5 * * * *')." + }, + "delay": { + "type": "string", + "description": "Delay for one-shot tasks (e.g. '30m', '2h', '1d')." + }, + "run_at": { + "type": "string", + "description": "Absolute RFC3339 time for one-shot tasks (e.g. '2030-01-01T00:00:00Z')." + }, + "command": { + "type": "string", + "description": "Shell command to execute. Required for create/add/once." + }, + "id": { + "type": "string", + "description": "Task ID. Required for get/cancel/remove/pause/resume." + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let action = args + .get("action") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'action' parameter"))?; + + match action { + "list" => self.handle_list(), + "get" => { + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for get action"))?; + self.handle_get(id) + } + "create" | "add" | "once" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + self.handle_create_like(action, &args) + } + "cancel" | "remove" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for cancel action"))?; + Ok(self.handle_cancel(id)) + } + "pause" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for pause action"))?; + Ok(self.handle_pause_resume(id, true)) + } + "resume" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for resume action"))?; + Ok(self.handle_pause_resume(id, false)) + } + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{other}'. Use create/add/once/list/get/cancel/remove/pause/resume." + )), + }), + } + } +} + +impl ScheduleTool { + fn enforce_mutation_allowed(&self, action: &str) -> Option { + if !self.security.can_act() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Security policy: read-only mode, cannot perform '{action}'" + )), + }); + } + + if !self.security.record_action() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".to_string()), + }); + } + + None + } + + fn handle_list(&self) -> Result { + let jobs = cron::list_jobs(&self.config)?; + if jobs.is_empty() { + return Ok(ToolResult { + success: true, + output: "No scheduled jobs.".to_string(), + error: None, + }); + } + + let mut lines = Vec::with_capacity(jobs.len()); + for job in jobs { + let paused = !job.enabled; + let one_shot = matches!(job.schedule, cron::Schedule::At { .. }); + let flags = match (paused, one_shot) { + (true, true) => " [disabled, one-shot]", + (true, false) => " [disabled]", + (false, true) => " [one-shot]", + (false, false) => "", + }; + let last_run = job + .last_run + .map_or_else(|| "never".to_string(), |value| value.to_rfc3339()); + let last_status = job.last_status.unwrap_or_else(|| "n/a".to_string()); + lines.push(format!( + "- {} | {} | next={} | last={} ({}){} | cmd: {}", + job.id, + job.expression, + job.next_run.to_rfc3339(), + last_run, + last_status, + flags, + job.command + )); + } + + Ok(ToolResult { + success: true, + output: format!("Scheduled jobs ({}):\n{}", lines.len(), lines.join("\n")), + error: None, + }) + } + + fn handle_get(&self, id: &str) -> Result { + match cron::get_job(&self.config, id) { + Ok(job) => { + let detail = json!({ + "id": job.id, + "expression": job.expression, + "command": job.command, + "next_run": job.next_run.to_rfc3339(), + "last_run": job.last_run.map(|value| value.to_rfc3339()), + "last_status": job.last_status, + "enabled": job.enabled, + "one_shot": matches!(job.schedule, cron::Schedule::At { .. }), + }); + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&detail)?, + error: None, + }) + } + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Job '{id}' not found")), + }), + } + } + + fn handle_create_like(&self, action: &str, args: &serde_json::Value) -> Result { + let command = args + .get("command") + .and_then(|value| value.as_str()) + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing or empty 'command' parameter"))?; + + let expression = args.get("expression").and_then(|value| value.as_str()); + let delay = args.get("delay").and_then(|value| value.as_str()); + let run_at = args.get("run_at").and_then(|value| value.as_str()); + + match action { + "add" => { + if expression.is_none() || delay.is_some() || run_at.is_some() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'add' requires 'expression' and forbids delay/run_at".into()), + }); + } + } + "once" => { + if expression.is_some() || (delay.is_none() && run_at.is_none()) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'once' requires exactly one of 'delay' or 'run_at'".into()), + }); + } + if delay.is_some() && run_at.is_some() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'once' supports either delay or run_at, not both".into()), + }); + } + } + _ => { + let count = [expression.is_some(), delay.is_some(), run_at.is_some()] + .into_iter() + .filter(|value| *value) + .count(); + if count != 1 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "Exactly one of 'expression', 'delay', or 'run_at' must be provided" + .into(), + ), + }); + } + } + } + + if let Some(value) = expression { + let job = cron::add_job(&self.config, value, command)?; + return Ok(ToolResult { + success: true, + output: format!( + "Created recurring job {} (expr: {}, next: {}, cmd: {})", + job.id, + job.expression, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }); + } + + if let Some(value) = delay { + let job = cron::add_once(&self.config, value, command)?; + return Ok(ToolResult { + success: true, + output: format!( + "Created one-shot job {} (runs at: {}, cmd: {})", + job.id, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }); + } + + let run_at_raw = run_at.ok_or_else(|| anyhow::anyhow!("Missing scheduling parameters"))?; + let run_at_parsed: DateTime = DateTime::parse_from_rfc3339(run_at_raw) + .map_err(|error| anyhow::anyhow!("Invalid run_at timestamp: {error}"))? + .with_timezone(&Utc); + + let job = cron::add_once_at(&self.config, run_at_parsed, command)?; + Ok(ToolResult { + success: true, + output: format!( + "Created one-shot job {} (runs at: {}, cmd: {})", + job.id, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }) + } + + fn handle_cancel(&self, id: &str) -> ToolResult { + match cron::remove_job(&self.config, id) { + Ok(()) => ToolResult { + success: true, + output: format!("Cancelled job {id}"), + error: None, + }, + Err(error) => ToolResult { + success: false, + output: String::new(), + error: Some(error.to_string()), + }, + } + } + + fn handle_pause_resume(&self, id: &str, pause: bool) -> ToolResult { + let operation = if pause { + cron::pause_job(&self.config, id) + } else { + cron::resume_job(&self.config, id) + }; + + match operation { + Ok(_) => ToolResult { + success: true, + output: if pause { + format!("Paused job {id}") + } else { + format!("Resumed job {id}") + }, + error: None, + }, + Err(error) => ToolResult { + success: false, + output: String::new(), + error: Some(error.to_string()), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::AutonomyLevel; + use tempfile::TempDir; + + fn test_setup() -> (TempDir, Config, Arc) { + let tmp = TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + (tmp, config, security) + } + + #[test] + fn tool_name_and_schema() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + assert_eq!(tool.name(), "schedule"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + } + + #[tokio::test] + async fn list_empty() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let result = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("No scheduled jobs")); + } + + #[tokio::test] + async fn create_get_and_cancel_roundtrip() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let create = tool + .execute(json!({ + "action": "create", + "expression": "*/5 * * * *", + "command": "echo hello" + })) + .await + .unwrap(); + assert!(create.success); + assert!(create.output.contains("Created recurring job")); + + let list = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(list.success); + assert!(list.output.contains("echo hello")); + + let id = create.output.split_whitespace().nth(3).unwrap(); + + let get = tool + .execute(json!({"action": "get", "id": id})) + .await + .unwrap(); + assert!(get.success); + assert!(get.output.contains("echo hello")); + + let cancel = tool + .execute(json!({"action": "cancel", "id": id})) + .await + .unwrap(); + assert!(cancel.success); + } + + #[tokio::test] + async fn once_and_pause_resume_aliases_work() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let once = tool + .execute(json!({ + "action": "once", + "delay": "30m", + "command": "echo delayed" + })) + .await + .unwrap(); + assert!(once.success); + + let add = tool + .execute(json!({ + "action": "add", + "expression": "*/10 * * * *", + "command": "echo recurring" + })) + .await + .unwrap(); + assert!(add.success); + + let id = add.output.split_whitespace().nth(3).unwrap(); + let pause = tool + .execute(json!({"action": "pause", "id": id})) + .await + .unwrap(); + assert!(pause.success); + + let resume = tool + .execute(json!({"action": "resume", "id": id})) + .await + .unwrap(); + assert!(resume.success); + } + + #[tokio::test] + async fn readonly_blocks_mutating_actions() { + let tmp = TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + autonomy: crate::config::AutonomyConfig { + level: AutonomyLevel::ReadOnly, + ..Default::default() + }, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let tool = ScheduleTool::new(security, config); + + let blocked = tool + .execute(json!({ + "action": "create", + "expression": "* * * * *", + "command": "echo blocked" + })) + .await + .unwrap(); + assert!(!blocked.success); + assert!(blocked.error.as_deref().unwrap().contains("read-only")); + + let list = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(list.success); + } + + #[tokio::test] + async fn unknown_action_returns_failure() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let result = tool.execute(json!({"action": "explode"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Unknown action")); + } +} diff --git a/src/tools/schema.rs b/src/tools/schema.rs new file mode 100644 index 0000000..e651993 --- /dev/null +++ b/src/tools/schema.rs @@ -0,0 +1,838 @@ +//! JSON Schema cleaning and validation for LLM tool-calling compatibility. +//! +//! Different providers support different subsets of JSON Schema. This module +//! normalizes tool schemas to improve cross-provider compatibility while +//! preserving semantic intent. +//! +//! ## What this module does +//! +//! 1. Removes unsupported keywords per provider strategy +//! 2. Resolves local `$ref` entries from `$defs` and `definitions` +//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum` +//! 4. Strips nullable variants from unions and `type` arrays +//! 5. Converts `const` to single-value `enum` +//! 6. Detects circular references and stops recursion safely +//! +//! # Example +//! +//! ```rust +//! use serde_json::json; +//! use zeroclaw::tools::schema::SchemaCleanr; +//! +//! let dirty_schema = json!({ +//! "type": "object", +//! "properties": { +//! "name": { +//! "type": "string", +//! "minLength": 1, // Gemini rejects this +//! "pattern": "^[a-z]+$" // Gemini rejects this +//! }, +//! "age": { +//! "$ref": "#/$defs/Age" // Needs resolution +//! } +//! }, +//! "$defs": { +//! "Age": { +//! "type": "integer", +//! "minimum": 0 // Gemini rejects this +//! } +//! } +//! }); +//! +//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema); +//! +//! // Result: +//! // { +//! // "type": "object", +//! // "properties": { +//! // "name": { "type": "string" }, +//! // "age": { "type": "integer" } +//! // } +//! // } +//! ``` +//! +use serde_json::{json, Map, Value}; +use std::collections::{HashMap, HashSet}; + +/// Keywords that Gemini rejects for tool schemas. +pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[ + // Schema composition + "$ref", + "$schema", + "$id", + "$defs", + "definitions", + // Property constraints + "additionalProperties", + "patternProperties", + // String constraints + "minLength", + "maxLength", + "pattern", + "format", + // Number constraints + "minimum", + "maximum", + "multipleOf", + // Array constraints + "minItems", + "maxItems", + "uniqueItems", + // Object constraints + "minProperties", + "maxProperties", + // Non-standard + "examples", // OpenAPI keyword, not JSON Schema +]; + +/// Keywords that should be preserved during cleaning (metadata). +const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"]; + +/// Schema cleaning strategies for different LLM providers. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CleaningStrategy { + /// Gemini (Google AI / Vertex AI) - Most restrictive + Gemini, + /// Anthropic Claude - Moderately permissive + Anthropic, + /// OpenAI GPT - Most permissive + OpenAI, + /// Conservative: Remove only universally unsupported keywords + Conservative, +} + +impl CleaningStrategy { + /// Get the list of unsupported keywords for this strategy. + pub fn unsupported_keywords(self) -> &'static [&'static str] { + match self { + Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS, + Self::Anthropic => &["$ref", "$defs", "definitions"], // Anthropic doesn't resolve refs + Self::OpenAI => &[], // OpenAI is most permissive + Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"], + } + } +} + +/// JSON Schema cleaner optimized for LLM tool calling. +pub struct SchemaCleanr; + +impl SchemaCleanr { + /// Clean schema for Gemini compatibility (strictest). + /// + /// This is the most aggressive cleaning strategy, removing all keywords + /// that Gemini's API rejects. + pub fn clean_for_gemini(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::Gemini) + } + + /// Clean schema for Anthropic compatibility. + pub fn clean_for_anthropic(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::Anthropic) + } + + /// Clean schema for OpenAI compatibility (most permissive). + pub fn clean_for_openai(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::OpenAI) + } + + /// Clean schema with specified strategy. + pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value { + // Extract $defs for reference resolution + let defs = if let Some(obj) = schema.as_object() { + Self::extract_defs(obj) + } else { + HashMap::new() + }; + + Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new()) + } + + /// Validate that a schema is suitable for LLM tool calling. + /// + /// Returns an error if the schema is invalid or missing required fields. + pub fn validate(schema: &Value) -> anyhow::Result<()> { + let obj = schema + .as_object() + .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?; + + // Must have 'type' field + if !obj.contains_key("type") { + anyhow::bail!("Schema missing required 'type' field"); + } + + // If type is 'object', should have 'properties' + if let Some(Value::String(t)) = obj.get("type") { + if t == "object" && !obj.contains_key("properties") { + tracing::warn!("Object schema without 'properties' field may cause issues"); + } + } + + Ok(()) + } + + // -------------------------------------------------------------------- + // Internal implementation + // -------------------------------------------------------------------- + + /// Extract $defs and definitions into a flat map for reference resolution. + fn extract_defs(obj: &Map) -> HashMap { + let mut defs = HashMap::new(); + + // Extract from $defs (JSON Schema 2019-09+) + if let Some(Value::Object(defs_obj)) = obj.get("$defs") { + for (key, value) in defs_obj { + defs.insert(key.clone(), value.clone()); + } + } + + // Extract from definitions (JSON Schema draft-07) + if let Some(Value::Object(defs_obj)) = obj.get("definitions") { + for (key, value) in defs_obj { + defs.insert(key.clone(), value.clone()); + } + } + + defs + } + + /// Recursively clean a schema value. + fn clean_with_defs( + schema: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + match schema { + Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack), + Value::Array(arr) => Value::Array( + arr.into_iter() + .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack)) + .collect(), + ), + other => other, + } + } + + /// Clean an object schema. + fn clean_object( + obj: Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + // Handle $ref resolution + if let Some(Value::String(ref_value)) = obj.get("$ref") { + return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack); + } + + // Handle anyOf/oneOf simplification + if obj.contains_key("anyOf") || obj.contains_key("oneOf") { + if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) { + return simplified; + } + } + + // Build cleaned object + let mut cleaned = Map::new(); + let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect(); + let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf"); + + for (key, value) in obj { + // Skip unsupported keywords + if unsupported.contains(key.as_str()) { + continue; + } + + // Special handling for specific keys + match key.as_str() { + // Convert const to enum + "const" => { + cleaned.insert("enum".to_string(), json!([value])); + } + // Skip type if we have anyOf/oneOf (they define the type) + "type" if has_union => { + // Skip + } + // Handle type arrays (remove null) + "type" if matches!(value, Value::Array(_)) => { + let cleaned_value = Self::clean_type_array(value); + cleaned.insert(key, cleaned_value); + } + // Recursively clean nested schemas + "properties" => { + let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + "items" => { + let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + "anyOf" | "oneOf" | "allOf" => { + let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + // Keep all other keys, cleaning nested objects/arrays recursively. + _ => { + let cleaned_value = match value { + Value::Object(_) | Value::Array(_) => { + Self::clean_with_defs(value, defs, strategy, ref_stack) + } + other => other, + }; + cleaned.insert(key, cleaned_value); + } + } + } + + Value::Object(cleaned) + } + + /// Resolve a $ref to its definition. + fn resolve_ref( + ref_value: &str, + obj: &Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + // Prevent circular references + if ref_stack.contains(ref_value) { + tracing::warn!("Circular $ref detected: {}", ref_value); + return Self::preserve_meta(obj, Value::Object(Map::new())); + } + + // Try to resolve local ref (#/$defs/Name or #/definitions/Name) + if let Some(def_name) = Self::parse_local_ref(ref_value) { + if let Some(definition) = defs.get(def_name.as_str()) { + ref_stack.insert(ref_value.to_string()); + let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); + ref_stack.remove(ref_value); + return Self::preserve_meta(obj, cleaned); + } + } + + // Can't resolve: return empty object with metadata + tracing::warn!("Cannot resolve $ref: {}", ref_value); + Self::preserve_meta(obj, Value::Object(Map::new())) + } + + /// Parse a local JSON Pointer ref (#/$defs/Name). + fn parse_local_ref(ref_value: &str) -> Option { + ref_value + .strip_prefix("#/$defs/") + .or_else(|| ref_value.strip_prefix("#/definitions/")) + .map(Self::decode_json_pointer) + } + + /// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`). + fn decode_json_pointer(segment: &str) -> String { + if !segment.contains('~') { + return segment.to_string(); + } + + let mut decoded = String::with_capacity(segment.len()); + let mut chars = segment.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '~' { + match chars.peek().copied() { + Some('0') => { + chars.next(); + decoded.push('~'); + } + Some('1') => { + chars.next(); + decoded.push('/'); + } + _ => decoded.push('~'), + } + } else { + decoded.push(ch); + } + } + + decoded + } + + /// Try to simplify anyOf/oneOf to a simpler form. + fn try_simplify_union( + obj: &Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Option { + let union_key = if obj.contains_key("anyOf") { + "anyOf" + } else if obj.contains_key("oneOf") { + "oneOf" + } else { + return None; + }; + + let variants = obj.get(union_key)?.as_array()?; + + // Clean all variants first + let cleaned_variants: Vec = variants + .iter() + .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack)) + .collect(); + + // Strip null variants + let non_null: Vec = cleaned_variants + .into_iter() + .filter(|v| !Self::is_null_schema(v)) + .collect(); + + // If only one variant remains after stripping nulls, return it + if non_null.len() == 1 { + return Some(Self::preserve_meta(obj, non_null[0].clone())); + } + + // Try to flatten to enum if all variants are literals + if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) { + return Some(Self::preserve_meta(obj, enum_value)); + } + + None + } + + /// Check if a schema represents null type. + fn is_null_schema(value: &Value) -> bool { + if let Some(obj) = value.as_object() { + // { const: null } + if let Some(Value::Null) = obj.get("const") { + return true; + } + // { enum: [null] } + if let Some(Value::Array(arr)) = obj.get("enum") { + if arr.len() == 1 && matches!(arr[0], Value::Null) { + return true; + } + } + // { type: "null" } + if let Some(Value::String(t)) = obj.get("type") { + if t == "null" { + return true; + } + } + } + false + } + + /// Try to flatten anyOf/oneOf with only literal values to enum. + /// + /// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}` + fn try_flatten_literal_union(variants: &[Value]) -> Option { + if variants.is_empty() { + return None; + } + + let mut all_values = Vec::new(); + let mut common_type: Option = None; + + for variant in variants { + let obj = variant.as_object()?; + + // Extract literal value from const or single-item enum + let literal_value = if let Some(const_val) = obj.get("const") { + const_val.clone() + } else if let Some(Value::Array(arr)) = obj.get("enum") { + if arr.len() == 1 { + arr[0].clone() + } else { + return None; + } + } else { + return None; + }; + + // Check type consistency + let variant_type = obj.get("type")?.as_str()?; + match &common_type { + None => common_type = Some(variant_type.to_string()), + Some(t) if t != variant_type => return None, + _ => {} + } + + all_values.push(literal_value); + } + + common_type.map(|t| { + json!({ + "type": t, + "enum": all_values + }) + }) + } + + /// Clean type array, removing null. + fn clean_type_array(value: Value) -> Value { + if let Value::Array(types) = value { + let non_null: Vec = types + .into_iter() + .filter(|v| v.as_str() != Some("null")) + .collect(); + + match non_null.len() { + 0 => Value::String("null".to_string()), + 1 => non_null + .into_iter() + .next() + .unwrap_or(Value::String("null".to_string())), + _ => Value::Array(non_null), + } + } else { + value + } + } + + /// Clean properties object. + fn clean_properties( + value: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + if let Value::Object(props) = value { + let cleaned: Map = props + .into_iter() + .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack))) + .collect(); + Value::Object(cleaned) + } else { + value + } + } + + /// Clean union (anyOf/oneOf/allOf). + fn clean_union( + value: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + if let Value::Array(variants) = value { + let cleaned: Vec = variants + .into_iter() + .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack)) + .collect(); + Value::Array(cleaned) + } else { + value + } + } + + /// Preserve metadata (description, title, default) from source to target. + fn preserve_meta(source: &Map, mut target: Value) -> Value { + if let Value::Object(target_obj) = &mut target { + for &key in SCHEMA_META_KEYS { + if let Some(value) = source.get(key) { + target_obj.insert(key.to_string(), value.clone()); + } + } + } + target + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_remove_unsupported_keywords() { + let schema = json!({ + "type": "string", + "minLength": 1, + "maxLength": 100, + "pattern": "^[a-z]+$", + "description": "A lowercase string" + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert_eq!(cleaned["description"], "A lowercase string"); + assert!(cleaned.get("minLength").is_none()); + assert!(cleaned.get("maxLength").is_none()); + assert!(cleaned.get("pattern").is_none()); + } + + #[test] + fn test_resolve_ref() { + let schema = json!({ + "type": "object", + "properties": { + "age": { + "$ref": "#/$defs/Age" + } + }, + "$defs": { + "Age": { + "type": "integer", + "minimum": 0 + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["properties"]["age"]["type"], "integer"); + assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy + assert!(cleaned.get("$defs").is_none()); + } + + #[test] + fn test_flatten_literal_union() { + let schema = json!({ + "anyOf": [ + { "const": "admin", "type": "string" }, + { "const": "user", "type": "string" }, + { "const": "guest", "type": "string" } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert!(cleaned["enum"].is_array()); + let enum_values = cleaned["enum"].as_array().unwrap(); + assert_eq!(enum_values.len(), 3); + assert!(enum_values.contains(&json!("admin"))); + assert!(enum_values.contains(&json!("user"))); + assert!(enum_values.contains(&json!("guest"))); + } + + #[test] + fn test_strip_null_from_union() { + let schema = json!({ + "oneOf": [ + { "type": "string" }, + { "type": "null" } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + // Should simplify to just { type: "string" } + assert_eq!(cleaned["type"], "string"); + assert!(cleaned.get("oneOf").is_none()); + } + + #[test] + fn test_const_to_enum() { + let schema = json!({ + "const": "fixed_value", + "description": "A constant" + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["enum"], json!(["fixed_value"])); + assert_eq!(cleaned["description"], "A constant"); + assert!(cleaned.get("const").is_none()); + } + + #[test] + fn test_preserve_metadata() { + let schema = json!({ + "$ref": "#/$defs/Name", + "description": "User's name", + "title": "Name Field", + "default": "Anonymous", + "$defs": { + "Name": { + "type": "string" + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert_eq!(cleaned["description"], "User's name"); + assert_eq!(cleaned["title"], "Name Field"); + assert_eq!(cleaned["default"], "Anonymous"); + } + + #[test] + fn test_circular_ref_prevention() { + let schema = json!({ + "type": "object", + "properties": { + "parent": { + "$ref": "#/$defs/Node" + } + }, + "$defs": { + "Node": { + "type": "object", + "properties": { + "child": { + "$ref": "#/$defs/Node" + } + } + } + } + }); + + // Should not panic on circular reference + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["properties"]["parent"]["type"], "object"); + // Circular reference should be broken + } + + #[test] + fn test_validate_schema() { + let valid = json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }); + + assert!(SchemaCleanr::validate(&valid).is_ok()); + + let invalid = json!({ + "properties": { + "name": { "type": "string" } + } + }); + + assert!(SchemaCleanr::validate(&invalid).is_err()); + } + + #[test] + fn test_strategy_differences() { + let schema = json!({ + "type": "string", + "minLength": 1, + "description": "A string field" + }); + + // Gemini: Most restrictive (removes minLength) + let gemini = SchemaCleanr::clean_for_gemini(schema.clone()); + assert!(gemini.get("minLength").is_none()); + assert_eq!(gemini["type"], "string"); + assert_eq!(gemini["description"], "A string field"); + + // OpenAI: Most permissive (keeps minLength) + let openai = SchemaCleanr::clean_for_openai(schema.clone()); + assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords + assert_eq!(openai["type"], "string"); + } + + #[test] + fn test_nested_properties() { + let schema = json!({ + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 1 + } + }, + "additionalProperties": false + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert!(cleaned["properties"]["user"]["properties"]["name"] + .get("minLength") + .is_none()); + assert!(cleaned["properties"]["user"] + .get("additionalProperties") + .is_none()); + } + + #[test] + fn test_type_array_null_removal() { + let schema = json!({ + "type": ["string", "null"] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + // Should simplify to just "string" + assert_eq!(cleaned["type"], "string"); + } + + #[test] + fn test_type_array_only_null_preserved() { + let schema = json!({ + "type": ["null"] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "null"); + } + + #[test] + fn test_ref_with_json_pointer_escape() { + let schema = json!({ + "$ref": "#/$defs/Foo~1Bar", + "$defs": { + "Foo/Bar": { + "type": "string" + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + } + + #[test] + fn test_skip_type_when_non_simplifiable_union_exists() { + let schema = json!({ + "type": "object", + "oneOf": [ + { + "type": "object", + "properties": { + "a": { "type": "string" } + } + }, + { + "type": "object", + "properties": { + "b": { "type": "number" } + } + } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert!(cleaned.get("type").is_none()); + assert!(cleaned.get("oneOf").is_some()); + } + + #[test] + fn test_clean_nested_unknown_schema_keyword() { + let schema = json!({ + "not": { + "$ref": "#/$defs/Age" + }, + "$defs": { + "Age": { + "type": "integer", + "minimum": 0 + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["not"]["type"], "integer"); + assert!(cleaned["not"].get("minimum").is_none()); + } +} diff --git a/src/tools/screenshot.rs b/src/tools/screenshot.rs new file mode 100644 index 0000000..7581bc1 --- /dev/null +++ b/src/tools/screenshot.rs @@ -0,0 +1,300 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::fmt::Write; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +/// Maximum time to wait for a screenshot command to complete. +const SCREENSHOT_TIMEOUT_SECS: u64 = 15; +/// Maximum base64 payload size to return (2 MB of base64 ≈ 1.5 MB image). +const MAX_BASE64_BYTES: usize = 2_097_152; + +/// Tool for capturing screenshots using platform-native commands. +/// +/// macOS: `screencapture` +/// Linux: tries `gnome-screenshot`, `scrot`, `import` (`ImageMagick`) in order. +pub struct ScreenshotTool { + security: Arc, +} + +impl ScreenshotTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Determine the screenshot command for the current platform. + fn screenshot_command(output_path: &str) -> Option> { + if cfg!(target_os = "macos") { + Some(vec![ + "screencapture".into(), + "-x".into(), // no sound + output_path.into(), + ]) + } else if cfg!(target_os = "linux") { + Some(vec![ + "sh".into(), + "-c".into(), + format!( + "if command -v gnome-screenshot >/dev/null 2>&1; then \ + gnome-screenshot -f '{output_path}'; \ + elif command -v scrot >/dev/null 2>&1; then \ + scrot '{output_path}'; \ + elif command -v import >/dev/null 2>&1; then \ + import -window root '{output_path}'; \ + else \ + echo 'NO_SCREENSHOT_TOOL' >&2; exit 1; \ + fi" + ), + ]) + } else { + None + } + } + + /// Execute the screenshot capture and return the result. + async fn capture(&self, args: serde_json::Value) -> anyhow::Result { + let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); + let filename = args + .get("filename") + .and_then(|v| v.as_str()) + .map_or_else(|| format!("screenshot_{timestamp}.png"), String::from); + + // Sanitize filename to prevent path traversal + let safe_name = PathBuf::from(&filename).file_name().map_or_else( + || format!("screenshot_{timestamp}.png"), + |n| n.to_string_lossy().to_string(), + ); + + let output_path = self.security.workspace_dir.join(&safe_name); + let output_str = output_path.to_string_lossy().to_string(); + + let Some(mut cmd_args) = Self::screenshot_command(&output_str) else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Screenshot not supported on this platform".into()), + }); + }; + + // macOS region flags + if cfg!(target_os = "macos") { + if let Some(region) = args.get("region").and_then(|v| v.as_str()) { + match region { + "selection" => cmd_args.insert(1, "-s".into()), + "window" => cmd_args.insert(1, "-w".into()), + _ => {} // ignore unknown regions + } + } + } + + let program = cmd_args.remove(0); + let result = tokio::time::timeout( + Duration::from_secs(SCREENSHOT_TIMEOUT_SECS), + tokio::process::Command::new(&program) + .args(&cmd_args) + .output(), + ) + .await; + + match result { + Ok(Ok(output)) => { + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + if stderr.contains("NO_SCREENSHOT_TOOL") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No screenshot tool found. Install gnome-screenshot, scrot, or ImageMagick." + .into(), + ), + }); + } + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Screenshot command failed: {stderr}")), + }); + } + + Self::read_and_encode(&output_path).await + } + Ok(Err(e)) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute screenshot command: {e}")), + }), + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Screenshot timed out after {SCREENSHOT_TIMEOUT_SECS}s" + )), + }), + } + } + + /// Read the screenshot file and return base64-encoded result. + async fn read_and_encode(output_path: &std::path::Path) -> anyhow::Result { + // Check file size before reading to prevent OOM on large screenshots + const MAX_RAW_BYTES: u64 = 1_572_864; // ~1.5 MB (base64 expands ~33%) + if let Ok(meta) = tokio::fs::metadata(output_path).await { + if meta.len() > MAX_RAW_BYTES { + return Ok(ToolResult { + success: true, + output: format!( + "Screenshot saved to: {}\nSize: {} bytes (too large to base64-encode inline)", + output_path.display(), + meta.len(), + ), + error: None, + }); + } + } + + match tokio::fs::read(output_path).await { + Ok(bytes) => { + use base64::Engine; + let size = bytes.len(); + let mut encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let truncated = if encoded.len() > MAX_BASE64_BYTES { + encoded.truncate(encoded.floor_char_boundary(MAX_BASE64_BYTES)); + true + } else { + false + }; + + let mut output_msg = format!( + "Screenshot saved to: {}\nSize: {size} bytes\nBase64 length: {}", + output_path.display(), + encoded.len(), + ); + if truncated { + output_msg.push_str(" (truncated)"); + } + let mime = match output_path.extension().and_then(|e| e.to_str()) { + Some("jpg" | "jpeg") => "image/jpeg", + Some("bmp") => "image/bmp", + Some("gif") => "image/gif", + Some("webp") => "image/webp", + _ => "image/png", + }; + let _ = write!(output_msg, "\ndata:{mime};base64,{encoded}"); + + Ok(ToolResult { + success: true, + output: output_msg, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: format!("Screenshot saved to: {}", output_path.display()), + error: Some(format!("Failed to read screenshot file: {e}")), + }), + } + } +} + +#[async_trait] +impl Tool for ScreenshotTool { + fn name(&self) -> &str { + "screenshot" + } + + fn description(&self) -> &str { + "Capture a screenshot of the current screen. Returns the file path and base64-encoded PNG data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Optional filename (default: screenshot_.png). Saved in workspace." + }, + "region": { + "type": "string", + "description": "Optional region for macOS: 'selection' for interactive crop, 'window' for front window. Ignored on Linux." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + self.capture(args).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + #[test] + fn screenshot_tool_name() { + let tool = ScreenshotTool::new(test_security()); + assert_eq!(tool.name(), "screenshot"); + } + + #[test] + fn screenshot_tool_description() { + let tool = ScreenshotTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("screenshot")); + } + + #[test] + fn screenshot_tool_schema() { + let tool = ScreenshotTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["filename"].is_object()); + assert!(schema["properties"]["region"].is_object()); + } + + #[test] + fn screenshot_tool_spec() { + let tool = ScreenshotTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "screenshot"); + assert!(spec.parameters.is_object()); + } + + #[test] + #[cfg(any(target_os = "macos", target_os = "linux"))] + fn screenshot_command_exists() { + let cmd = ScreenshotTool::screenshot_command("/tmp/test.png"); + assert!(cmd.is_some()); + let args = cmd.unwrap(); + assert!(!args.is_empty()); + } + + #[test] + fn screenshot_command_contains_output_path() { + let cmd = ScreenshotTool::screenshot_command("/tmp/my_screenshot.png").unwrap(); + let joined = cmd.join(" "); + assert!( + joined.contains("/tmp/my_screenshot.png"), + "Command should contain the output path" + ); + } +} diff --git a/src/tools/shell.rs b/src/tools/shell.rs index 92a5582..662d7ab 100644 --- a/src/tools/shell.rs +++ b/src/tools/shell.rs @@ -1,4 +1,5 @@ use super::traits::{Tool, ToolResult}; +use crate::runtime::RuntimeAdapter; use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; @@ -9,15 +10,21 @@ use std::time::Duration; const SHELL_TIMEOUT_SECS: u64 = 60; /// Maximum output size in bytes (1MB). const MAX_OUTPUT_BYTES: usize = 1_048_576; +/// Environment variables safe to pass to shell commands. +/// Only functional variables are included — never API keys or secrets. +const SAFE_ENV_VARS: &[&str] = &[ + "PATH", "HOME", "TERM", "LANG", "LC_ALL", "LC_CTYPE", "USER", "SHELL", "TMPDIR", +]; /// Shell command execution tool with sandboxing pub struct ShellTool { security: Arc, + runtime: Arc, } impl ShellTool { - pub fn new(security: Arc) -> Self { - Self { security } + pub fn new(security: Arc, runtime: Arc) -> Self { + Self { security, runtime } } } @@ -38,6 +45,11 @@ impl Tool for ShellTool { "command": { "type": "string", "description": "The shell command to execute" + }, + "approved": { + "type": "boolean", + "description": "Set true to explicitly approve medium/high-risk commands in supervised mode", + "default": false } }, "required": ["command"] @@ -49,26 +61,64 @@ impl Tool for ShellTool { .get("command") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?; + let approved = args + .get("approved") + .and_then(|v| v.as_bool()) + .unwrap_or(false); - // Security check: validate command against allowlist - if !self.security.is_command_allowed(command) { + if self.security.is_rate_limited() { return Ok(ToolResult { success: false, output: String::new(), - error: Some(format!("Command not allowed by security policy: {command}")), + error: Some("Rate limit exceeded: too many actions in the last hour".into()), }); } - // Execute with timeout to prevent hanging commands - let result = tokio::time::timeout( - Duration::from_secs(SHELL_TIMEOUT_SECS), - tokio::process::Command::new("sh") - .arg("-c") - .arg(command) - .current_dir(&self.security.workspace_dir) - .output(), - ) - .await; + match self.security.validate_command_execution(command, approved) { + Ok(_) => {} + Err(reason) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(reason), + }); + } + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), + }); + } + + // Execute with timeout to prevent hanging commands. + // Clear the environment to prevent leaking API keys and other secrets + // (CWE-200), then re-add only safe, functional variables. + let mut cmd = match self + .runtime + .build_shell_command(command, &self.security.workspace_dir) + { + Ok(cmd) => cmd, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to build runtime command: {e}")), + }); + } + }; + cmd.env_clear(); + + for var in SAFE_ENV_VARS { + if let Ok(val) = std::env::var(var) { + cmd.env(var, val); + } + } + + let result = + tokio::time::timeout(Duration::from_secs(SHELL_TIMEOUT_SECS), cmd.output()).await; match result { Ok(Ok(output)) => { @@ -77,11 +127,11 @@ impl Tool for ShellTool { // Truncate output to prevent OOM if stdout.len() > MAX_OUTPUT_BYTES { - stdout.truncate(MAX_OUTPUT_BYTES); + stdout.truncate(stdout.floor_char_boundary(MAX_OUTPUT_BYTES)); stdout.push_str("\n... [output truncated at 1MB]"); } if stderr.len() > MAX_OUTPUT_BYTES { - stderr.truncate(MAX_OUTPUT_BYTES); + stderr.truncate(stderr.floor_char_boundary(MAX_OUTPUT_BYTES)); stderr.push_str("\n... [stderr truncated at 1MB]"); } @@ -114,6 +164,7 @@ impl Tool for ShellTool { #[cfg(test)] mod tests { use super::*; + use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::{AutonomyLevel, SecurityPolicy}; fn test_security(autonomy: AutonomyLevel) -> Arc { @@ -124,32 +175,37 @@ mod tests { }) } + fn test_runtime() -> Arc { + Arc::new(NativeRuntime::new()) + } + #[test] fn shell_tool_name() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); assert_eq!(tool.name(), "shell"); } #[test] fn shell_tool_description() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); assert!(!tool.description().is_empty()); } #[test] fn shell_tool_schema_has_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let schema = tool.parameters_schema(); assert!(schema["properties"]["command"].is_object()); assert!(schema["required"] .as_array() .unwrap() .contains(&json!("command"))); + assert!(schema["properties"]["approved"].is_object()); } #[tokio::test] async fn shell_executes_allowed_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool .execute(json!({"command": "echo hello"})) .await @@ -161,15 +217,16 @@ mod tests { #[tokio::test] async fn shell_blocks_disallowed_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("not allowed")); + let error = result.error.as_deref().unwrap_or(""); + assert!(error.contains("not allowed") || error.contains("high-risk")); } #[tokio::test] async fn shell_blocks_readonly() { - let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly)); + let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime()); let result = tool.execute(json!({"command": "ls"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("not allowed")); @@ -177,7 +234,7 @@ mod tests { #[tokio::test] async fn shell_missing_command_param() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({})).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("command")); @@ -185,18 +242,127 @@ mod tests { #[tokio::test] async fn shell_wrong_type_param() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({"command": 123})).await; assert!(result.is_err()); } #[tokio::test] async fn shell_captures_exit_code() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool .execute(json!({"command": "ls /nonexistent_dir_xyz"})) .await .unwrap(); assert!(!result.success); } + + fn test_security_with_env_cmd() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + workspace_dir: std::env::temp_dir(), + allowed_commands: vec!["env".into(), "echo".into()], + ..SecurityPolicy::default() + }) + } + + /// RAII guard that restores an environment variable to its original state on drop, + /// ensuring cleanup even if the test panics. + struct EnvGuard { + key: &'static str, + original: Option, + } + + impl EnvGuard { + fn set(key: &'static str, value: &str) -> Self { + let original = std::env::var(key).ok(); + std::env::set_var(key, value); + Self { key, original } + } + } + + impl Drop for EnvGuard { + fn drop(&mut self) { + match &self.original { + Some(val) => std::env::set_var(self.key, val), + None => std::env::remove_var(self.key), + } + } + } + + #[tokio::test(flavor = "current_thread")] + async fn shell_does_not_leak_api_key() { + let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345"); + let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890"); + + let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime()); + let result = tool.execute(json!({"command": "env"})).await.unwrap(); + assert!(result.success); + assert!( + !result.output.contains("sk-test-secret-12345"), + "API_KEY leaked to shell command output" + ); + assert!( + !result.output.contains("sk-test-secret-67890"), + "ZEROCLAW_API_KEY leaked to shell command output" + ); + } + + #[tokio::test] + async fn shell_preserves_path_and_home() { + let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime()); + + let result = tool + .execute(json!({"command": "echo $HOME"})) + .await + .unwrap(); + assert!(result.success); + assert!( + !result.output.trim().is_empty(), + "HOME should be available in shell" + ); + + let result = tool + .execute(json!({"command": "echo $PATH"})) + .await + .unwrap(); + assert!(result.success); + assert!( + !result.output.trim().is_empty(), + "PATH should be available in shell" + ); + } + + #[tokio::test] + async fn shell_requires_approval_for_medium_risk_command() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + allowed_commands: vec!["touch".into()], + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }); + + let tool = ShellTool::new(security.clone(), test_runtime()); + let denied = tool + .execute(json!({"command": "touch zeroclaw_shell_approval_test"})) + .await + .unwrap(); + assert!(!denied.success); + assert!(denied + .error + .as_deref() + .unwrap_or("") + .contains("explicit approval")); + + let allowed = tool + .execute(json!({ + "command": "touch zeroclaw_shell_approval_test", + "approved": true + })) + .await + .unwrap(); + assert!(allowed.success); + + let _ = std::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test")); + } } diff --git a/src/tools/traits.rs b/src/tools/traits.rs index 714e83b..0a12606 100644 --- a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -41,3 +41,81 @@ pub trait Tool: Send + Sync { } } } + +#[cfg(test)] +mod tests { + use super::*; + + struct DummyTool; + + #[async_trait] + impl Tool for DummyTool { + fn name(&self) -> &str { + "dummy_tool" + } + + fn description(&self) -> &str { + "A deterministic test tool" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + Ok(ToolResult { + success: true, + output: args + .get("value") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(), + error: None, + }) + } + } + + #[test] + fn spec_uses_tool_metadata_and_schema() { + let tool = DummyTool; + let spec = tool.spec(); + + assert_eq!(spec.name, "dummy_tool"); + assert_eq!(spec.description, "A deterministic test tool"); + assert_eq!(spec.parameters["type"], "object"); + assert_eq!(spec.parameters["properties"]["value"]["type"], "string"); + } + + #[tokio::test] + async fn execute_returns_expected_output() { + let tool = DummyTool; + let result = tool + .execute(serde_json::json!({ "value": "hello-tool" })) + .await + .unwrap(); + + assert!(result.success); + assert_eq!(result.output, "hello-tool"); + assert!(result.error.is_none()); + } + + #[test] + fn tool_result_serialization_roundtrip() { + let result = ToolResult { + success: false, + output: String::new(), + error: Some("boom".into()), + }; + + let json = serde_json::to_string(&result).unwrap(); + let parsed: ToolResult = serde_json::from_str(&json).unwrap(); + + assert!(!parsed.success); + assert_eq!(parsed.error.as_deref(), Some("boom")); + } +} diff --git a/src/tunnel/cloudflare.rs b/src/tunnel/cloudflare.rs index e387099..d92cbb7 100644 --- a/src/tunnel/cloudflare.rs +++ b/src/tunnel/cloudflare.rs @@ -109,3 +109,33 @@ impl Tunnel for CloudflareTunnel { .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructor_stores_token() { + let tunnel = CloudflareTunnel::new("cf-token".into()); + assert_eq!(tunnel.token, "cf-token"); + } + + #[test] + fn public_url_is_none_before_start() { + let tunnel = CloudflareTunnel::new("cf-token".into()); + assert!(tunnel.public_url().is_none()); + } + + #[tokio::test] + async fn stop_without_started_process_is_ok() { + let tunnel = CloudflareTunnel::new("cf-token".into()); + let result = tunnel.stop().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn health_check_is_false_before_start() { + let tunnel = CloudflareTunnel::new("cf-token".into()); + assert!(!tunnel.health_check().await); + } +} diff --git a/src/tunnel/custom.rs b/src/tunnel/custom.rs index c65ff32..ef962b4 100644 --- a/src/tunnel/custom.rs +++ b/src/tunnel/custom.rs @@ -143,3 +143,78 @@ impl Tunnel for CustomTunnel { .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn start_with_empty_command_returns_error() { + let tunnel = CustomTunnel::new(" ".into(), None, None); + let result = tunnel.start("127.0.0.1", 8080).await; + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("start_command is empty")); + } + + #[tokio::test] + async fn start_without_pattern_returns_local_url() { + let tunnel = CustomTunnel::new("sleep 1".into(), None, None); + + let url = tunnel.start("127.0.0.1", 4455).await.unwrap(); + assert_eq!(url, "http://127.0.0.1:4455"); + assert_eq!( + tunnel.public_url().as_deref(), + Some("http://127.0.0.1:4455") + ); + + tunnel.stop().await.unwrap(); + } + + #[tokio::test] + async fn start_with_pattern_extracts_url() { + let tunnel = CustomTunnel::new( + "echo https://public.example".into(), + None, + Some("public.example".into()), + ); + + let url = tunnel.start("localhost", 9999).await.unwrap(); + + assert_eq!(url, "https://public.example"); + assert_eq!( + tunnel.public_url().as_deref(), + Some("https://public.example") + ); + + tunnel.stop().await.unwrap(); + } + + #[tokio::test] + async fn start_replaces_host_and_port_placeholders() { + let tunnel = CustomTunnel::new( + "echo http://{host}:{port}".into(), + None, + Some("http://".into()), + ); + + let url = tunnel.start("10.1.2.3", 4321).await.unwrap(); + + assert_eq!(url, "http://10.1.2.3:4321"); + tunnel.stop().await.unwrap(); + } + + #[tokio::test] + async fn health_check_with_unreachable_health_url_returns_false() { + let tunnel = CustomTunnel::new( + "sleep 1".into(), + Some("http://127.0.0.1:9/healthz".into()), + None, + ); + + assert!(!tunnel.health_check().await); + } +} diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 0682a1b..6a852d8 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -128,6 +128,7 @@ mod tests { use crate::config::schema::{ CloudflareTunnelConfig, CustomTunnelConfig, NgrokTunnelConfig, TunnelConfig, }; + use tokio::process::Command; /// Helper: assert `create_tunnel` returns an error containing `needle`. fn assert_tunnel_err(cfg: &TunnelConfig, needle: &str) { @@ -313,4 +314,62 @@ mod tests { assert_eq!(t.name(), "custom"); assert!(t.public_url().is_none()); } + + #[tokio::test] + async fn kill_shared_no_process_is_ok() { + let proc = new_shared_process(); + let result = kill_shared(&proc).await; + + assert!(result.is_ok()); + assert!(proc.lock().await.is_none()); + } + + #[tokio::test] + async fn kill_shared_terminates_and_clears_child() { + let proc = new_shared_process(); + + let child = Command::new("sleep") + .arg("30") + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn() + .expect("sleep should spawn for lifecycle test"); + + { + let mut guard = proc.lock().await; + *guard = Some(TunnelProcess { + child, + public_url: "https://example.test".into(), + }); + } + + kill_shared(&proc).await.unwrap(); + + let guard = proc.lock().await; + assert!(guard.is_none()); + } + + #[tokio::test] + async fn cloudflare_health_false_before_start() { + let tunnel = CloudflareTunnel::new("tok".into()); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn ngrok_health_false_before_start() { + let tunnel = NgrokTunnel::new("tok".into(), None); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn tailscale_health_false_before_start() { + let tunnel = TailscaleTunnel::new(false, None); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn custom_health_false_before_start_without_health_url() { + let tunnel = CustomTunnel::new("echo hi".into(), None, Some("https://".into())); + assert!(!tunnel.health_check().await); + } } diff --git a/src/tunnel/ngrok.rs b/src/tunnel/ngrok.rs index e993e79..7d16a11 100644 --- a/src/tunnel/ngrok.rs +++ b/src/tunnel/ngrok.rs @@ -119,3 +119,33 @@ impl Tunnel for NgrokTunnel { .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructor_stores_domain() { + let tunnel = NgrokTunnel::new("ngrok-token".into(), Some("my.ngrok.app".into())); + assert_eq!(tunnel.domain.as_deref(), Some("my.ngrok.app")); + } + + #[test] + fn public_url_is_none_before_start() { + let tunnel = NgrokTunnel::new("ngrok-token".into(), None); + assert!(tunnel.public_url().is_none()); + } + + #[tokio::test] + async fn stop_without_started_process_is_ok() { + let tunnel = NgrokTunnel::new("ngrok-token".into(), None); + let result = tunnel.stop().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn health_check_is_false_before_start() { + let tunnel = NgrokTunnel::new("ngrok-token".into(), None); + assert!(!tunnel.health_check().await); + } +} diff --git a/src/tunnel/none.rs b/src/tunnel/none.rs index a8de838..dc7189a 100644 --- a/src/tunnel/none.rs +++ b/src/tunnel/none.rs @@ -26,3 +26,39 @@ impl Tunnel for NoneTunnel { None } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn name_is_none() { + let tunnel = NoneTunnel; + assert_eq!(tunnel.name(), "none"); + } + + #[tokio::test] + async fn start_returns_local_url() { + let tunnel = NoneTunnel; + let url = tunnel.start("127.0.0.1", 7788).await.unwrap(); + assert_eq!(url, "http://127.0.0.1:7788"); + } + + #[tokio::test] + async fn stop_is_noop_success() { + let tunnel = NoneTunnel; + assert!(tunnel.stop().await.is_ok()); + } + + #[tokio::test] + async fn health_check_is_always_true() { + let tunnel = NoneTunnel; + assert!(tunnel.health_check().await); + } + + #[test] + fn public_url_is_always_none() { + let tunnel = NoneTunnel; + assert!(tunnel.public_url().is_none()); + } +} diff --git a/src/tunnel/tailscale.rs b/src/tunnel/tailscale.rs index 4a69038..f983d8e 100644 --- a/src/tunnel/tailscale.rs +++ b/src/tunnel/tailscale.rs @@ -100,3 +100,34 @@ impl Tunnel for TailscaleTunnel { .and_then(|g| g.as_ref().map(|tp| tp.public_url.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constructor_stores_hostname_and_mode() { + let tunnel = TailscaleTunnel::new(true, Some("myhost.tailnet.ts.net".into())); + assert!(tunnel.funnel); + assert_eq!(tunnel.hostname.as_deref(), Some("myhost.tailnet.ts.net")); + } + + #[test] + fn public_url_is_none_before_start() { + let tunnel = TailscaleTunnel::new(false, None); + assert!(tunnel.public_url().is_none()); + } + + #[tokio::test] + async fn health_check_is_false_before_start() { + let tunnel = TailscaleTunnel::new(false, None); + assert!(!tunnel.health_check().await); + } + + #[tokio::test] + async fn stop_without_started_process_is_ok() { + let tunnel = TailscaleTunnel::new(false, None); + let result = tunnel.stop().await; + assert!(result.is_ok()); + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..9a218e7 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,138 @@ +//! Utility functions for `ZeroClaw`. +//! +//! This module contains reusable helper functions used across the codebase. + +/// Truncate a string to at most `max_chars` characters, appending "..." if truncated. +/// +/// This function safely handles multi-byte UTF-8 characters (emoji, CJK, accented characters) +/// by using character boundaries instead of byte indices. +/// +/// # Arguments +/// * `s` - The string to truncate +/// * `max_chars` - Maximum number of characters to keep (excluding "...") +/// +/// # Returns +/// * Original string if length <= `max_chars` +/// * Truncated string with "..." appended if length > `max_chars` +/// +/// # Examples +/// ``` +/// use zeroclaw::util::truncate_with_ellipsis; +/// +/// // ASCII string - no truncation needed +/// assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); +/// +/// // ASCII string - truncation needed +/// assert_eq!(truncate_with_ellipsis("hello world", 5), "hello..."); +/// +/// // Multi-byte UTF-8 (emoji) - safe truncation +/// assert_eq!(truncate_with_ellipsis("Hello 🦀 World", 8), "Hello 🦀..."); +/// assert_eq!(truncate_with_ellipsis("😀😀😀😀", 2), "😀😀..."); +/// +/// // Empty string +/// assert_eq!(truncate_with_ellipsis("", 10), ""); +/// ``` +pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String { + match s.char_indices().nth(max_chars) { + Some((idx, _)) => { + let truncated = &s[..idx]; + // Trim trailing whitespace for cleaner output + format!("{}...", truncated.trim_end()) + } + None => s.to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_truncate_ascii_no_truncation() { + // ASCII string shorter than limit - no change + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hello world", 50), "hello world"); + } + + #[test] + fn test_truncate_ascii_with_truncation() { + // ASCII string longer than limit - truncates + assert_eq!(truncate_with_ellipsis("hello world", 5), "hello..."); + assert_eq!( + truncate_with_ellipsis("This is a long message", 10), + "This is a..." + ); + } + + #[test] + fn test_truncate_empty_string() { + assert_eq!(truncate_with_ellipsis("", 10), ""); + } + + #[test] + fn test_truncate_at_exact_boundary() { + // String exactly at boundary - no truncation + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + } + + #[test] + fn test_truncate_emoji_single() { + // Single emoji (4 bytes) - should not panic + let s = "🦀"; + assert_eq!(truncate_with_ellipsis(s, 10), s); + assert_eq!(truncate_with_ellipsis(s, 1), s); + } + + #[test] + fn test_truncate_emoji_multiple() { + // Multiple emoji - safe truncation at character boundary + let s = "😀😀😀😀"; // 4 emoji, each 4 bytes = 16 bytes total + assert_eq!(truncate_with_ellipsis(s, 2), "😀😀..."); + assert_eq!(truncate_with_ellipsis(s, 3), "😀😀😀..."); + } + + #[test] + fn test_truncate_mixed_ascii_emoji() { + // Mixed ASCII and emoji + assert_eq!(truncate_with_ellipsis("Hello 🦀 World", 8), "Hello 🦀..."); + assert_eq!(truncate_with_ellipsis("Hi 😊", 10), "Hi 😊"); + } + + #[test] + fn test_truncate_cjk_characters() { + // CJK characters (Chinese - each is 3 bytes) + let s = "这是一个测试消息用来触发崩溃的中文"; // 21 characters + let result = truncate_with_ellipsis(s, 16); + assert!(result.ends_with("...")); + assert!(result.is_char_boundary(result.len() - 1)); + } + + #[test] + fn test_truncate_accented_characters() { + // Accented characters (2 bytes each in UTF-8) + let s = "café résumé naïve"; + assert_eq!(truncate_with_ellipsis(s, 10), "café résum..."); + } + + #[test] + fn test_truncate_unicode_edge_case() { + // Mix of 1-byte, 2-byte, 3-byte, and 4-byte characters + let s = "aé你好🦀"; // 1 + 1 + 2 + 2 + 4 bytes = 10 bytes, 5 chars + assert_eq!(truncate_with_ellipsis(s, 3), "aé你..."); + } + + #[test] + fn test_truncate_long_string() { + // Long ASCII string + let s = "a".repeat(200); + let result = truncate_with_ellipsis(&s, 50); + assert_eq!(result.len(), 53); // 50 + "..." + assert!(result.ends_with("...")); + } + + #[test] + fn test_truncate_zero_max_chars() { + // Edge case: max_chars = 0 + assert_eq!(truncate_with_ellipsis("hello", 0), "..."); + } +} diff --git a/test_helpers/generate_test_messages.py b/test_helpers/generate_test_messages.py new file mode 100755 index 0000000..17a59af --- /dev/null +++ b/test_helpers/generate_test_messages.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +""" +Test message generator for Telegram integration testing. +Generates messages of various lengths for testing message splitting. +""" + +import sys + +def generate_short_message(): + """Generate a short message (< 100 chars)""" + return "Hello! This is a short test message." + +def generate_medium_message(): + """Generate a medium message (~ 1000 chars)""" + return "This is a medium-length test message. " * 25 + +def generate_long_message(): + """Generate a long message (~ 5000 chars, > 4096 limit)""" + return "This is a very long test message that will be split into multiple chunks. " * 70 + +def generate_exact_limit_message(): + """Generate a message exactly at 4096 char limit""" + base = "x" * 4096 + return base + +def generate_over_limit_message(): + """Generate a message just over the 4096 char limit""" + return "x" * 4200 + +def generate_multi_chunk_message(): + """Generate a message that requires 3+ chunks""" + return "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 250 + +def generate_newline_message(): + """Generate a message with many newlines (tests newline splitting)""" + return "Line of text\n" * 400 + +def generate_word_boundary_message(): + """Generate a message with clear word boundaries""" + return "word " * 1000 + +def print_message_info(message, name): + """Print information about a message""" + print(f"\n{'='*60}") + print(f"{name}") + print(f"{'='*60}") + print(f"Length: {len(message)} characters") + print(f"Will split: {'Yes' if len(message) > 4096 else 'No'}") + if len(message) > 4096: + chunks = (len(message) + 4095) // 4096 + print(f"Estimated chunks: {chunks}") + print(f"{'='*60}") + print(message[:200] + "..." if len(message) > 200 else message) + print(f"{'='*60}\n") + +def main(): + if len(sys.argv) > 1: + test_type = sys.argv[1].lower() + else: + print("Usage: python3 generate_test_messages.py [type]") + print("\nAvailable types:") + print(" short - Short message (< 100 chars)") + print(" medium - Medium message (~1000 chars)") + print(" long - Long message (~5000 chars, requires splitting)") + print(" exact - Exactly 4096 chars") + print(" over - Just over 4096 chars") + print(" multi - Very long (3+ chunks)") + print(" newline - Many newlines (tests line splitting)") + print(" word - Clear word boundaries") + print(" all - Show info for all types") + print("\nExample:") + print(" python3 generate_test_messages.py long") + sys.exit(1) + + messages = { + 'short': ('Short Message', generate_short_message()), + 'medium': ('Medium Message', generate_medium_message()), + 'long': ('Long Message', generate_long_message()), + 'exact': ('Exact Limit (4096)', generate_exact_limit_message()), + 'over': ('Just Over Limit', generate_over_limit_message()), + 'multi': ('Multi-Chunk Message', generate_multi_chunk_message()), + 'newline': ('Newline Test', generate_newline_message()), + 'word': ('Word Boundary Test', generate_word_boundary_message()), + } + + if test_type == 'all': + for name, msg in messages.values(): + print_message_info(msg, name) + elif test_type in messages: + name, msg = messages[test_type] + # Just print the message for piping to Telegram + print(msg) + else: + print(f"Error: Unknown type '{test_type}'") + print("Run without arguments to see available types.") + sys.exit(1) + +if __name__ == '__main__': + main() diff --git a/test_telegram_integration.sh b/test_telegram_integration.sh new file mode 100755 index 0000000..c0ce2b7 --- /dev/null +++ b/test_telegram_integration.sh @@ -0,0 +1,362 @@ +#!/bin/bash +# ZeroClaw Telegram Integration Test Suite +# Automated testing script for Telegram channel functionality + +set -e # Exit on error + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Test counters +TOTAL_TESTS=0 +PASSED_TESTS=0 +FAILED_TESTS=0 + +# Helper functions +print_header() { + echo -e "\n${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" + echo -e "${BLUE}$1${NC}" + echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}\n" +} + +print_test() { + TOTAL_TESTS=$((TOTAL_TESTS + 1)) + echo -e "${YELLOW}Test $TOTAL_TESTS:${NC} $1" +} + +pass() { + PASSED_TESTS=$((PASSED_TESTS + 1)) + echo -e "${GREEN}✓ PASS:${NC} $1\n" +} + +fail() { + FAILED_TESTS=$((FAILED_TESTS + 1)) + echo -e "${RED}✗ FAIL:${NC} $1\n" +} + +warn() { + echo -e "${YELLOW}⚠ WARNING:${NC} $1\n" +} + +# Banner +clear +cat << "EOF" + ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ + + ███████╗███████╗██████╗ ██████╗ ██████╗██╗ █████╗ ██╗ ██╗ + ╚══███╔╝██╔════╝██╔══██╗██╔═══██╗██╔════╝██║ ██╔══██╗██║ ██║ + ███╔╝ █████╗ ██████╔╝██║ ██║██║ ██║ ███████║██║ █╗ ██║ + ███╔╝ ██╔══╝ ██╔══██╗██║ ██║██║ ██║ ██╔══██║██║███╗██║ + ███████╗███████╗██║ ██║╚██████╔╝╚██████╗███████╗██║ ██║╚███╔███╔╝ + ╚══════╝╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝ + + 🧪 TELEGRAM INTEGRATION TEST SUITE 🧪 + + ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ +EOF + +echo -e "\n${BLUE}Started at:${NC} $(date)" +echo -e "${BLUE}Working directory:${NC} $(pwd)\n" + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 1: Code Quality Tests +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 1: Code Quality Tests" + +# Test 1: Cargo test compilation +print_test "Compiling test suite" +if cargo test --lib --no-run &>/dev/null; then + pass "Test suite compiles successfully" +else + fail "Test suite compilation failed" + exit 1 +fi + +# Test 2: Unit tests +print_test "Running Telegram unit tests" +TEST_OUTPUT=$(cargo test telegram --lib 2>&1) +if echo "$TEST_OUTPUT" | grep -q "test result: ok"; then + PASSED_COUNT=$(echo "$TEST_OUTPUT" | grep -oP '\d+(?= passed)' | head -1) + pass "All Telegram unit tests passed ($PASSED_COUNT tests)" +else + fail "Some unit tests failed" + echo "$TEST_OUTPUT" | grep "FAILED\|error" +fi + +# Test 3: Message splitting tests specifically +print_test "Verifying message splitting tests" +if cargo test telegram_split --lib --quiet 2>&1 | grep -q "8 passed"; then + pass "All 8 message splitting tests passed" +else + fail "Message splitting tests incomplete" +fi + +# Test 4: Clippy linting +print_test "Running Clippy lint checks" +if cargo clippy --all-targets --quiet 2>&1 | grep -qv "error:"; then + pass "No clippy errors found" +else + CLIPPY_ERRORS=$(cargo clippy --all-targets 2>&1 | grep "error:" | wc -l) + fail "Clippy found $CLIPPY_ERRORS error(s)" +fi + +# Test 5: Code formatting +print_test "Checking code formatting" +if cargo fmt --check &>/dev/null; then + pass "Code is properly formatted" +else + warn "Code formatting issues found (run 'cargo fmt' to fix)" +fi + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 2: Build Tests +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 2: Build Tests" + +# Test 6: Debug build +print_test "Debug build" +if cargo build --quiet 2>&1; then + pass "Debug build successful" +else + fail "Debug build failed" +fi + +# Test 7: Release build +print_test "Release build with optimizations" +START_TIME=$(date +%s) +if cargo build --release --quiet 2>&1; then + END_TIME=$(date +%s) + BUILD_TIME=$((END_TIME - START_TIME)) + pass "Release build successful (${BUILD_TIME}s)" +else + fail "Release build failed" +fi + +# Test 8: Binary size check +print_test "Binary size verification" +if [ -f "target/release/zeroclaw" ]; then + BINARY_SIZE=$(ls -lh target/release/zeroclaw | awk '{print $5}') + SIZE_BYTES=$(stat -f%z target/release/zeroclaw 2>/dev/null || stat -c%s target/release/zeroclaw) + SIZE_MB=$((SIZE_BYTES / 1024 / 1024)) + + if [ $SIZE_MB -le 10 ]; then + pass "Binary size is optimal: $BINARY_SIZE (${SIZE_MB}MB)" + else + warn "Binary size is larger than expected: $BINARY_SIZE (${SIZE_MB}MB)" + fi +else + fail "Release binary not found" +fi + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 3: Configuration Tests +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 3: Configuration Tests" + +# Test 9: Config file existence +print_test "Configuration file check" +CONFIG_PATH="$HOME/.zeroclaw/config.toml" +if [ -f "$CONFIG_PATH" ]; then + pass "Config file exists at $CONFIG_PATH" + + # Test 10: Telegram config + print_test "Telegram configuration check" + if grep -q "\[channels_config.telegram\]" "$CONFIG_PATH"; then + pass "Telegram configuration found" + + # Test 11: Bot token configured + print_test "Bot token validation" + if grep -q "bot_token = \"" "$CONFIG_PATH"; then + pass "Bot token is configured" + else + warn "Bot token not set - integration tests will be skipped" + fi + + # Test 12: Allowlist configured + print_test "User allowlist validation" + if grep -q "allowed_users = \[" "$CONFIG_PATH"; then + pass "User allowlist is configured" + else + warn "User allowlist not set" + fi + else + warn "Telegram not configured - run 'zeroclaw onboard' first" + fi +else + warn "No config file found - run 'zeroclaw onboard' first" +fi + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 4: Health Check Tests +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 4: Health Check Tests" + +# Test 13: Health check timeout +print_test "Health check timeout (should complete in <5s)" +START_TIME=$(date +%s) +HEALTH_OUTPUT=$(timeout 10 target/release/zeroclaw channel doctor 2>&1 || true) +END_TIME=$(date +%s) +HEALTH_TIME=$((END_TIME - START_TIME)) + +if [ $HEALTH_TIME -le 6 ]; then + pass "Health check completed in ${HEALTH_TIME}s (timeout fix working)" +else + warn "Health check took ${HEALTH_TIME}s (expected <5s)" +fi + +# Test 14: Telegram connectivity +print_test "Telegram API connectivity" +if echo "$HEALTH_OUTPUT" | grep -q "Telegram.*healthy"; then + pass "Telegram channel is healthy" +elif echo "$HEALTH_OUTPUT" | grep -q "Telegram.*unhealthy"; then + warn "Telegram channel is unhealthy - check bot token" +elif echo "$HEALTH_OUTPUT" | grep -q "Telegram.*timed out"; then + warn "Telegram health check timed out - network issue?" +else + warn "Could not determine Telegram health status" +fi + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 5: Feature Validation Tests +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 5: Feature Validation Tests" + +# Test 15: Message splitting function exists +print_test "Message splitting function implementation" +if grep -q "fn split_message_for_telegram" src/channels/telegram.rs; then + pass "Message splitting function implemented" +else + fail "Message splitting function not found" +fi + +# Test 16: Message length constant +print_test "Telegram message length constant" +if grep -q "const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096" src/channels/telegram.rs; then + pass "TELEGRAM_MAX_MESSAGE_LENGTH constant defined correctly" +else + fail "Message length constant missing or incorrect" +fi + +# Test 17: Timeout implementation +print_test "Health check timeout implementation" +if grep -q "tokio::time::timeout" src/channels/telegram.rs; then + pass "Timeout mechanism implemented in health_check" +else + fail "Timeout not implemented in health_check" +fi + +# Test 18: chat_id validation +print_test "chat_id validation implementation" +if grep -q "let Some(chat_id) = chat_id else" src/channels/telegram.rs; then + pass "chat_id validation implemented" +else + fail "chat_id validation missing" +fi + +# Test 19: Duration import +print_test "std::time::Duration import" +if grep -q "use std::time::Duration" src/channels/telegram.rs; then + pass "Duration import added" +else + fail "Duration import missing" +fi + +# Test 20: Continuation markers +print_test "Multi-part message markers" +if grep -q "(continues...)" src/channels/telegram.rs && grep -q "(continued)" src/channels/telegram.rs; then + pass "Continuation markers implemented for split messages" +else + fail "Continuation markers missing" +fi + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Phase 6: Integration Test Preparation +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Phase 6: Manual Integration Tests" + +echo -e "${BLUE}The following tests require manual interaction:${NC}\n" + +cat << 'EOF' +📱 Manual Test Checklist: + +1. [ ] Start the channel: + zeroclaw channel start + +2. [ ] Send a short message to your bot in Telegram: + "Hello bot!" + ✓ Verify: Bot responds within 3 seconds + +3. [ ] Send a long message (>4096 characters): + python3 -c 'print("test " * 1000)' + ✓ Verify: Message is split into chunks + ✓ Verify: Chunks have (continues...) and (continued) markers + ✓ Verify: All chunks arrive in order + +4. [ ] Test unauthorized access: + - Edit config: allowed_users = ["999999999"] + - Send a message + ✓ Verify: Warning log appears + ✓ Verify: Message is ignored + - Restore correct user ID + +5. [ ] Test rapid messages (10 messages in 5 seconds): + ✓ Verify: All messages are processed + ✓ Verify: No rate limit errors + ✓ Verify: Responses have delays + +6. [ ] Check logs for errors: + RUST_LOG=debug zeroclaw channel start + ✓ Verify: No unexpected errors + ✓ Verify: "missing chat_id" appears for malformed messages + ✓ Verify: Health check logs show "timed out" if needed + +EOF + +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Test Summary +# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +print_header "Test Summary" + +echo -e "${BLUE}Total Tests:${NC} $TOTAL_TESTS" +echo -e "${GREEN}Passed:${NC} $PASSED_TESTS" +echo -e "${RED}Failed:${NC} $FAILED_TESTS" +echo -e "${YELLOW}Warnings:${NC} $((TOTAL_TESTS - PASSED_TESTS - FAILED_TESTS))" + +PASS_RATE=$((PASSED_TESTS * 100 / TOTAL_TESTS)) +echo -e "\n${BLUE}Pass Rate:${NC} ${PASS_RATE}%" + +if [ $FAILED_TESTS -eq 0 ]; then + echo -e "\n${GREEN}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" + echo -e "${GREEN}✓ ALL AUTOMATED TESTS PASSED! 🎉${NC}" + echo -e "${GREEN}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}\n" + + echo -e "${BLUE}Next Steps:${NC}" + echo -e "1. Run manual integration tests (see checklist above)" + echo -e "2. Deploy to production when ready" + echo -e "3. Monitor logs for issues\n" + + exit 0 +else + echo -e "\n${RED}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}" + echo -e "${RED}✗ SOME TESTS FAILED${NC}" + echo -e "${RED}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}\n" + + echo -e "${BLUE}Troubleshooting:${NC}" + echo -e "1. Review failed tests above" + echo -e "2. Run: cargo test telegram --lib -- --nocapture" + echo -e "3. Check: cargo clippy --all-targets" + echo -e "4. Fix issues and re-run this script\n" + + exit 1 +fi diff --git a/tests/memory_comparison.rs b/tests/memory_comparison.rs index 8e0f4d6..2523829 100644 --- a/tests/memory_comparison.rs +++ b/tests/memory_comparison.rs @@ -36,6 +36,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -49,6 +50,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -127,8 +129,8 @@ async fn compare_recall_quality() { ]; for (key, content, cat) in &entries { - sq.store(key, content, cat.clone()).await.unwrap(); - md.store(key, content, cat.clone()).await.unwrap(); + sq.store(key, content, cat.clone(), None).await.unwrap(); + md.store(key, content, cat.clone(), None).await.unwrap(); } // Test queries and compare results @@ -145,8 +147,8 @@ async fn compare_recall_quality() { println!("RECALL QUALITY (10 entries seeded):\n"); for (query, desc) in &queries { - let sq_results = sq.recall(query, 10).await.unwrap(); - let md_results = md.recall(query, 10).await.unwrap(); + let sq_results = sq.recall(query, 10, None).await.unwrap(); + let md_results = md.recall(query, 10, None).await.unwrap(); println!(" Query: \"{query}\" — {desc}"); println!(" SQLite: {} results", sq_results.len()); @@ -190,21 +192,21 @@ async fn compare_recall_speed() { } else { format!("TypeScript powers modern web apps, entry {i}") }; - sq.store(&format!("e{i}"), &content, MemoryCategory::Core) + sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None) .await .unwrap(); - md.store(&format!("e{i}"), &content, MemoryCategory::Daily) + md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None) .await .unwrap(); } // Benchmark recall let start = Instant::now(); - let sq_results = sq.recall("Rust systems", 10).await.unwrap(); + let sq_results = sq.recall("Rust systems", 10, None).await.unwrap(); let sq_dur = start.elapsed(); let start = Instant::now(); - let md_results = md.recall("Rust systems", 10).await.unwrap(); + let md_results = md.recall("Rust systems", 10, None).await.unwrap(); let md_dur = start.elapsed(); println!("\n============================================================"); @@ -227,15 +229,25 @@ async fn compare_persistence() { // Store in both, then drop and re-open { let sq = sqlite_backend(tmp_sq.path()); - sq.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + sq.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } { let md = markdown_backend(tmp_md.path()); - md.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + md.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } // Re-open @@ -282,17 +294,17 @@ async fn compare_upsert() { let md = markdown_backend(tmp_md.path()); // Store twice with same key, different content - sq.store("pref", "likes Rust", MemoryCategory::Core) + sq.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - sq.store("pref", "loves Rust", MemoryCategory::Core) + sq.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "likes Rust", MemoryCategory::Core) + md.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "loves Rust", MemoryCategory::Core) + md.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -300,7 +312,7 @@ async fn compare_upsert() { let md_count = md.count().await.unwrap(); let sq_entry = sq.get("pref").await.unwrap(); - let md_results = md.recall("loves Rust", 5).await.unwrap(); + let md_results = md.recall("loves Rust", 5, None).await.unwrap(); println!("\n============================================================"); println!("UPSERT (store same key twice):"); @@ -328,10 +340,10 @@ async fn compare_forget() { let sq = sqlite_backend(tmp_sq.path()); let md = markdown_backend(tmp_md.path()); - sq.store("secret", "API key: sk-1234", MemoryCategory::Core) + sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); - md.store("secret", "API key: sk-1234", MemoryCategory::Core) + md.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); @@ -372,37 +384,40 @@ async fn compare_category_filter() { let md = markdown_backend(tmp_md.path()); // Mix of categories - sq.store("a", "core fact 1", MemoryCategory::Core) + sq.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - sq.store("b", "core fact 2", MemoryCategory::Core) + sq.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - sq.store("c", "daily note", MemoryCategory::Daily) + sq.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - sq.store("d", "convo msg", MemoryCategory::Conversation) + sq.store("d", "convo msg", MemoryCategory::Conversation, None) .await .unwrap(); - md.store("a", "core fact 1", MemoryCategory::Core) + md.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - md.store("b", "core fact 2", MemoryCategory::Core) + md.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - md.store("c", "daily note", MemoryCategory::Daily) + md.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap(); - let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap(); - let sq_all = sq.list(None).await.unwrap(); + let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let sq_conv = sq + .list(Some(&MemoryCategory::Conversation), None) + .await + .unwrap(); + let sq_all = sq.list(None, None).await.unwrap(); - let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap(); - let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let md_all = md.list(None).await.unwrap(); + let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let md_all = md.list(None, None).await.unwrap(); println!("\n============================================================"); println!("CATEGORY FILTERING:"); diff --git a/tests/whatsapp_webhook_security.rs b/tests/whatsapp_webhook_security.rs new file mode 100644 index 0000000..3196d1e --- /dev/null +++ b/tests/whatsapp_webhook_security.rs @@ -0,0 +1,133 @@ +//! Integration tests for WhatsApp webhook signature verification. +//! +//! These tests validate that: +//! 1. Webhooks with valid signatures are accepted +//! 2. Webhooks with invalid signatures are rejected +//! 3. Webhooks with missing signatures are rejected +//! 4. Webhooks are rejected even if JSON is valid but signature is bad + +use hmac::{Hmac, Mac}; +use sha2::Sha256; + +/// Compute valid HMAC-SHA256 signature for a webhook payload +fn compute_signature(app_secret: &str, body: &[u8]) -> String { + let mut mac = Hmac::::new_from_slice(app_secret.as_bytes()).unwrap(); + mac.update(body); + let result = mac.finalize(); + format!("sha256={}", hex::encode(result.into_bytes())) +} + +#[test] +fn whatsapp_signature_rejects_missing_sha256_prefix() { + let secret = "test_app_secret"; + let body = b"test payload"; + let bad_sig = "abc123"; // Missing sha256= prefix + + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret, body, bad_sig + )); +} + +#[test] +fn whatsapp_signature_rejects_invalid_hex() { + let secret = "test_app_secret"; + let body = b"test payload"; + let bad_sig = "sha256=not-valid-hex!!"; + + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret, body, bad_sig + )); +} + +#[test] +fn whatsapp_signature_rejects_wrong_signature() { + let secret = "test_app_secret"; + let body = b"test payload"; + let bad_sig = "sha256=00112233445566778899aabbccddeeff"; + + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret, body, bad_sig + )); +} + +#[test] +fn whatsapp_signature_accepts_valid_signature() { + let secret = "test_app_secret"; + let body = b"test payload"; + let valid_sig = compute_signature(secret, body); + + assert!(zeroclaw::gateway::verify_whatsapp_signature( + secret, body, &valid_sig + )); +} + +#[test] +fn whatsapp_signature_rejects_tampered_body() { + let secret = "test_app_secret"; + let original_body = b"original message"; + let tampered_body = b"tampered message"; + + // Compute signature for original body + let sig = compute_signature(secret, original_body); + + // Tampered body should be rejected even with valid-looking signature + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret, + tampered_body, + &sig + )); +} + +#[test] +fn whatsapp_signature_rejects_wrong_secret() { + let correct_secret = "correct_secret"; + let wrong_secret = "wrong_secret"; + let body = b"test payload"; + + // Compute signature with correct secret + let sig = compute_signature(correct_secret, body); + + // Wrong secret should reject the signature + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + wrong_secret, + body, + &sig + )); +} + +#[test] +fn whatsapp_signature_rejects_empty_signature() { + let secret = "test_app_secret"; + let body = b"test payload"; + + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret, body, "" + )); +} + +#[test] +fn whatsapp_signature_different_secrets_produce_different_sigs() { + let secret1 = "secret_one"; + let secret2 = "secret_two"; + let body = b"same payload"; + + let sig1 = compute_signature(secret1, body); + let sig2 = compute_signature(secret2, body); + + // Different secrets should produce different signatures + assert_ne!(sig1, sig2); + + // Each signature should only verify with its own secret + assert!(zeroclaw::gateway::verify_whatsapp_signature( + secret1, body, &sig1 + )); + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret2, body, &sig1 + )); + assert!(zeroclaw::gateway::verify_whatsapp_signature( + secret2, body, &sig2 + )); + assert!(!zeroclaw::gateway::verify_whatsapp_signature( + secret1, body, &sig2 + )); +} diff --git a/zero-claw.jpeg b/zero-claw.jpeg new file mode 100644 index 0000000..b76a094 Binary files /dev/null and b/zero-claw.jpeg differ