feat: initial release — ZeroClaw v0.1.0
- 22 AI providers (OpenRouter, Anthropic, OpenAI, Mistral, etc.) - 7 channels (CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook) - 5-step onboarding wizard with Project Context personalization - OpenClaw-aligned system prompt (SOUL.md, IDENTITY.md, USER.md, AGENTS.md, etc.) - SQLite memory backend with auto-save - Skills system with on-demand loading - Security: autonomy levels, command allowlists, cost limits - 532 tests passing, 0 clippy warnings
This commit is contained in:
commit
05cb353f7f
71 changed files with 15757 additions and 0 deletions
65
.github/workflows/ci.yml
vendored
Normal file
65
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Check formatting
|
||||
run: cargo fmt -- --check
|
||||
|
||||
- name: Run clippy
|
||||
run: cargo clippy -- -D warnings
|
||||
|
||||
- name: Run tests
|
||||
run: cargo test --verbose
|
||||
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
- os: macos-latest
|
||||
target: x86_64-apple-darwin
|
||||
- os: macos-latest
|
||||
target: aarch64-apple-darwin
|
||||
- os: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Build release
|
||||
run: cargo build --release --target ${{ matrix.target }}
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: zeroclaw-${{ matrix.target }}
|
||||
path: target/${{ matrix.target }}/release/zeroclaw*
|
||||
90
.github/workflows/release.yml
vendored
Normal file
90
.github/workflows/release.yml
vendored
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags: ["v*"]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build-release:
|
||||
name: Build ${{ matrix.target }}
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
artifact: zeroclaw
|
||||
- os: macos-latest
|
||||
target: x86_64-apple-darwin
|
||||
artifact: zeroclaw
|
||||
- os: macos-latest
|
||||
target: aarch64-apple-darwin
|
||||
artifact: zeroclaw
|
||||
- os: windows-latest
|
||||
target: x86_64-pc-windows-msvc
|
||||
artifact: zeroclaw.exe
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
|
||||
- name: Build release
|
||||
run: cargo build --release --target ${{ matrix.target }}
|
||||
|
||||
- name: Check binary size (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
SIZE=$(stat -f%z target/${{ matrix.target }}/release/${{ matrix.artifact }} 2>/dev/null || stat -c%s target/${{ matrix.target }}/release/${{ matrix.artifact }})
|
||||
echo "Binary size: $((SIZE / 1024 / 1024))MB ($SIZE bytes)"
|
||||
if [ "$SIZE" -gt 5242880 ]; then
|
||||
echo "::warning::Binary exceeds 5MB target"
|
||||
fi
|
||||
|
||||
- name: Package (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
cd target/${{ matrix.target }}/release
|
||||
tar czf ../../../zeroclaw-${{ matrix.target }}.tar.gz ${{ matrix.artifact }}
|
||||
|
||||
- name: Package (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
cd target/${{ matrix.target }}/release
|
||||
7z a ../../../zeroclaw-${{ matrix.target }}.zip ${{ matrix.artifact }}
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: zeroclaw-${{ matrix.target }}
|
||||
path: zeroclaw-${{ matrix.target }}.*
|
||||
|
||||
publish:
|
||||
name: Publish Release
|
||||
needs: build-release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
generate_release_notes: true
|
||||
files: artifacts/**/*
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
37
.github/workflows/security.yml
vendored
Normal file
37
.github/workflows/security.yml
vendored
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
name: Security Audit
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
- cron: "0 6 * * 1" # Weekly on Monday 6am UTC
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
audit:
|
||||
name: Security Audit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Install cargo-audit
|
||||
run: cargo install cargo-audit
|
||||
|
||||
- name: Run cargo-audit
|
||||
run: cargo audit
|
||||
|
||||
deny:
|
||||
name: License & Supply Chain
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: EmbarkStudios/cargo-deny-action@v2
|
||||
with:
|
||||
command: check advisories licenses sources
|
||||
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
/target
|
||||
*.db
|
||||
*.db-journal
|
||||
33
CHANGELOG.md
Normal file
33
CHANGELOG.md
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to ZeroClaw will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.1.0] - 2025-02-13
|
||||
|
||||
### Added
|
||||
- **Core Architecture**: Trait-based pluggable system for Provider, Channel, Observer, RuntimeAdapter, Tool
|
||||
- **Provider**: OpenRouter implementation (access Claude, GPT-4, Llama, Gemini via single API)
|
||||
- **Channels**: CLI channel with interactive and single-message modes
|
||||
- **Observability**: NoopObserver (zero overhead), LogObserver (tracing), MultiObserver (fan-out)
|
||||
- **Security**: Workspace sandboxing, command allowlisting, path traversal blocking, autonomy levels (ReadOnly/Supervised/Full), rate limiting
|
||||
- **Tools**: Shell (sandboxed), FileRead (path-checked), FileWrite (path-checked)
|
||||
- **Memory (Brain)**: SQLite persistent backend (searchable, survives restarts), Markdown backend (plain files, human-readable)
|
||||
- **Heartbeat Engine**: Periodic task execution from HEARTBEAT.md
|
||||
- **Runtime**: Native adapter for Mac/Linux/Raspberry Pi
|
||||
- **Config**: TOML-based configuration with sensible defaults
|
||||
- **Onboarding**: Interactive CLI wizard with workspace scaffolding
|
||||
- **CLI Commands**: agent, gateway, status, cron, channel, tools, onboard
|
||||
- **CI/CD**: GitHub Actions with cross-platform builds (Linux, macOS Intel/ARM, Windows)
|
||||
- **Tests**: 159 inline tests covering all modules and edge cases
|
||||
- **Binary**: 3.1MB optimized release build (includes bundled SQLite)
|
||||
|
||||
### Security
|
||||
- Path traversal attack prevention
|
||||
- Command injection blocking
|
||||
- Workspace escape prevention
|
||||
- Forbidden system path protection (`/etc`, `/root`, `~/.ssh`)
|
||||
|
||||
[0.1.0]: https://github.com/theonlyhennygod/zeroclaw/releases/tag/v0.1.0
|
||||
209
CONTRIBUTING.md
Normal file
209
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
# Contributing to ZeroClaw
|
||||
|
||||
Thanks for your interest in contributing to ZeroClaw! This guide will help you get started.
|
||||
|
||||
## Development Setup
|
||||
|
||||
```bash
|
||||
# Clone the repo
|
||||
git clone https://github.com/theonlyhennygod/zeroclaw.git
|
||||
cd zeroclaw
|
||||
|
||||
# Build
|
||||
cargo build
|
||||
|
||||
# Run tests (180 tests, all must pass)
|
||||
cargo test
|
||||
|
||||
# Format & lint (must pass before PR)
|
||||
cargo fmt && cargo clippy -- -D warnings
|
||||
|
||||
# Release build (~3.1MB)
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
## Architecture: Trait-Based Pluggability
|
||||
|
||||
ZeroClaw's architecture is built on **traits** — every subsystem is swappable. This means contributing a new integration is as simple as implementing a trait and registering it in the factory function.
|
||||
|
||||
```
|
||||
src/
|
||||
├── providers/ # LLM backends → Provider trait
|
||||
├── channels/ # Messaging → Channel trait
|
||||
├── observability/ # Metrics/logging → Observer trait
|
||||
├── runtime/ # Platform adapters → RuntimeAdapter trait
|
||||
├── tools/ # Agent tools → Tool trait
|
||||
├── memory/ # Persistence/brain → Memory trait
|
||||
└── security/ # Sandboxing → SecurityPolicy
|
||||
```
|
||||
|
||||
## How to Add a New Provider
|
||||
|
||||
Create `src/providers/your_provider.rs`:
|
||||
|
||||
```rust
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use crate::providers::traits::Provider;
|
||||
|
||||
pub struct YourProvider {
|
||||
api_key: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl YourProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.unwrap_or_default().to_string(),
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for YourProvider {
|
||||
async fn chat(&self, message: &str, model: &str, temperature: f64) -> Result<String> {
|
||||
// Your API call here
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Then register it in `src/providers/mod.rs`:
|
||||
|
||||
```rust
|
||||
"your_provider" => Ok(Box::new(your_provider::YourProvider::new(api_key))),
|
||||
```
|
||||
|
||||
## How to Add a New Channel
|
||||
|
||||
Create `src/channels/your_channel.rs`:
|
||||
|
||||
```rust
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use tokio::sync::mpsc;
|
||||
use crate::channels::traits::{Channel, ChannelMessage};
|
||||
|
||||
pub struct YourChannel { /* config fields */ }
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for YourChannel {
|
||||
fn name(&self) -> &str { "your_channel" }
|
||||
|
||||
async fn send(&self, message: &str, recipient: &str) -> Result<()> {
|
||||
// Send message via your platform
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
|
||||
// Listen for incoming messages, forward to tx
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool { true }
|
||||
}
|
||||
```
|
||||
|
||||
## How to Add a New Observer
|
||||
|
||||
Create `src/observability/your_observer.rs`:
|
||||
|
||||
```rust
|
||||
use crate::observability::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
|
||||
pub struct YourObserver { /* client, config, etc. */ }
|
||||
|
||||
impl Observer for YourObserver {
|
||||
fn record_event(&self, event: &ObserverEvent) {
|
||||
// Push event to your backend
|
||||
}
|
||||
|
||||
fn record_metric(&self, metric: &ObserverMetric) {
|
||||
// Push metric to your backend
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { "your_observer" }
|
||||
}
|
||||
```
|
||||
|
||||
## How to Add a New Tool
|
||||
|
||||
Create `src/tools/your_tool.rs`:
|
||||
|
||||
```rust
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use serde_json::{json, Value};
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
pub struct YourTool { /* security policy, config, etc. */ }
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for YourTool {
|
||||
fn name(&self) -> &str { "your_tool" }
|
||||
|
||||
fn description(&self) -> &str { "Does something useful" }
|
||||
|
||||
fn parameters_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": { "type": "string", "description": "The input" }
|
||||
},
|
||||
"required": ["input"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let input = args["input"].as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'input'"))?;
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Processed: {input}"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Pull Request Checklist
|
||||
|
||||
- [ ] `cargo fmt` — code is formatted
|
||||
- [ ] `cargo clippy -- -D warnings` — no warnings
|
||||
- [ ] `cargo test` — all 129+ tests pass
|
||||
- [ ] New code has inline `#[cfg(test)]` tests
|
||||
- [ ] No new dependencies unless absolutely necessary (we optimize for binary size)
|
||||
- [ ] README updated if adding user-facing features
|
||||
- [ ] Follows existing code patterns and conventions
|
||||
|
||||
## Commit Convention
|
||||
|
||||
We use [Conventional Commits](https://www.conventionalcommits.org/):
|
||||
|
||||
```
|
||||
feat: add Anthropic provider
|
||||
fix: path traversal edge case with symlinks
|
||||
docs: update contributing guide
|
||||
test: add heartbeat unicode parsing tests
|
||||
refactor: extract common security checks
|
||||
chore: bump tokio to 1.43
|
||||
```
|
||||
|
||||
## Code Style
|
||||
|
||||
- **Minimal dependencies** — every crate adds to binary size
|
||||
- **Inline tests** — `#[cfg(test)] mod tests {}` at the bottom of each file
|
||||
- **Trait-first** — define the trait, then implement
|
||||
- **Security by default** — sandbox everything, allowlist, never blocklist
|
||||
- **No unwrap in production code** — use `?`, `anyhow`, or `thiserror`
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
- **Bugs**: Include OS, Rust version, steps to reproduce, expected vs actual
|
||||
- **Features**: Describe the use case, propose which trait to extend
|
||||
- **Security**: See [SECURITY.md](SECURITY.md) for responsible disclosure
|
||||
|
||||
## License
|
||||
|
||||
By contributing, you agree that your contributions will be licensed under the MIT License.
|
||||
2392
Cargo.lock
generated
Normal file
2392
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
76
Cargo.toml
Normal file
76
Cargo.toml
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
[package]
|
||||
name = "zeroclaw"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT"
|
||||
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
repository = "https://github.com/theonlyhennygod/zeroclaw"
|
||||
readme = "README.md"
|
||||
keywords = ["ai", "agent", "cli", "assistant", "chatbot"]
|
||||
categories = ["command-line-utilities", "api-bindings"]
|
||||
|
||||
[dependencies]
|
||||
# CLI - minimal and fast
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
|
||||
# Async runtime - feature-optimized for size
|
||||
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs"] }
|
||||
|
||||
# HTTP client - minimal features
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", default-features = false, features = ["derive"] }
|
||||
serde_json = { version = "1.0", default-features = false, features = ["std"] }
|
||||
|
||||
# Config
|
||||
directories = "5.0"
|
||||
toml = "0.8"
|
||||
shellexpand = "3.1"
|
||||
|
||||
# Logging - minimal
|
||||
tracing = { version = "0.1", default-features = false }
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
|
||||
|
||||
# Error handling
|
||||
anyhow = "1.0"
|
||||
thiserror = "2.0"
|
||||
|
||||
# UUID generation
|
||||
uuid = { version = "1.11", default-features = false, features = ["v4", "std"] }
|
||||
|
||||
# Async traits
|
||||
async-trait = "0.1"
|
||||
|
||||
# Memory / persistence
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
|
||||
|
||||
# Interactive CLI prompts
|
||||
dialoguer = { version = "0.11", features = ["fuzzy-select"] }
|
||||
console = "0.15"
|
||||
|
||||
# Discord WebSocket gateway
|
||||
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
||||
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
||||
hostname = "0.4.2"
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
lto = true # Link-time optimization
|
||||
codegen-units = 1 # Better optimization
|
||||
strip = true # Remove debug symbols
|
||||
panic = "abort" # Reduce binary size
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
opt-level = "z"
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
strip = true
|
||||
panic = "abort"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
tempfile = "3.14"
|
||||
21
Dockerfile
Normal file
21
Dockerfile
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# ── Stage 1: Build ────────────────────────────────────────────
|
||||
FROM rust:1.83-slim AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY src/ src/
|
||||
|
||||
RUN cargo build --release --locked && \
|
||||
strip target/release/zeroclaw
|
||||
|
||||
# ── Stage 2: Runtime (distroless — no shell, no OS, tiny) ────
|
||||
FROM gcr.io/distroless/cc-debian12
|
||||
|
||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
||||
|
||||
# Default workspace
|
||||
VOLUME ["/workspace"]
|
||||
ENV ZEROCLAW_WORKSPACE=/workspace
|
||||
|
||||
ENTRYPOINT ["zeroclaw"]
|
||||
CMD ["gateway"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2025 theonlyhennygod
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
278
README.md
Normal file
278
README.md
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
<p align="center">
|
||||
<img src="zeroclaw.png" alt="ZeroClaw" width="200" />
|
||||
</p>
|
||||
|
||||
<h1 align="center">ZeroClaw 🦀</h1>
|
||||
|
||||
<p align="center">
|
||||
<strong>Zero overhead. Zero compromise. 100% Rust. 100% Agnostic.</strong>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||
</p>
|
||||
|
||||
The fastest, smallest, fully autonomous AI assistant — deploy anywhere, swap anything.
|
||||
|
||||
```
|
||||
~3MB binary · <10ms startup · 502 tests · 22 providers · Pluggable everything
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
git clone https://github.com/theonlyhennygod/zeroclaw.git
|
||||
cd zeroclaw
|
||||
cargo build --release
|
||||
|
||||
# Initialize config + workspace
|
||||
cargo run --release -- onboard
|
||||
|
||||
# Set your API key
|
||||
export OPENROUTER_API_KEY="sk-..."
|
||||
|
||||
# Chat
|
||||
cargo run --release -- agent -m "Hello, ZeroClaw!"
|
||||
|
||||
# Interactive mode
|
||||
cargo run --release -- agent
|
||||
|
||||
# Check status
|
||||
cargo run --release -- status --verbose
|
||||
|
||||
# List tools (includes memory tools)
|
||||
cargo run --release -- tools list
|
||||
|
||||
# Test a tool directly
|
||||
cargo run --release -- tools test memory_store '{"key": "lang", "content": "User prefers Rust"}'
|
||||
cargo run --release -- tools test memory_recall '{"query": "Rust"}'
|
||||
```
|
||||
|
||||
> **Tip:** Run `cargo install --path .` to install `zeroclaw` globally, then use `zeroclaw` instead of `cargo run --release --`.
|
||||
|
||||
## Architecture
|
||||
|
||||
Every subsystem is a **trait** — swap implementations with a config change, zero code changes.
|
||||
|
||||
| Subsystem | Trait | Ships with | Extend |
|
||||
|-----------|-------|------------|--------|
|
||||
| **AI Models** | `Provider` | 22 providers (OpenRouter, Anthropic, OpenAI, Venice, Groq, Mistral, etc.) | Any OpenAI-compatible API |
|
||||
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook | Any messaging API |
|
||||
| **Memory** | `Memory` | SQLite (default), Markdown | Any persistence |
|
||||
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget | Any capability |
|
||||
| **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel |
|
||||
| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM |
|
||||
| **Security** | `SecurityPolicy` | Sandbox + allowlists + rate limits | — |
|
||||
| **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — |
|
||||
|
||||
### Memory System
|
||||
|
||||
ZeroClaw has a built-in brain. The agent automatically:
|
||||
1. **Recalls** relevant memories before each prompt (context injection)
|
||||
2. **Saves** conversation turns to memory (auto-save)
|
||||
3. **Manages** its own memory via tools (store/recall/forget)
|
||||
|
||||
Two backends — **SQLite** (default, searchable, upsert, delete) and **Markdown** (human-readable, append-only, git-friendly). Switch with one config line.
|
||||
|
||||
### Security
|
||||
|
||||
- **Workspace sandboxing** — can't escape workspace directory
|
||||
- **Command allowlisting** — only approved shell commands
|
||||
- **Path traversal blocking** — `..` and absolute paths blocked
|
||||
- **Rate limiting** — max actions/hour, max cost/day
|
||||
- **Autonomy levels** — ReadOnly, Supervised, Full
|
||||
|
||||
## Configuration
|
||||
|
||||
Config: `~/.zeroclaw/config.toml` (created by `onboard`)
|
||||
|
||||
## Documentation Index
|
||||
|
||||
Fetch the complete documentation index at: https://docs.openclaw.ai/llms.txt
|
||||
Use this file to discover all available pages before exploring further.
|
||||
|
||||
## Token Use & Costs
|
||||
|
||||
ZeroClaw tracks **tokens**, not characters. Tokens are model-specific, but most
|
||||
OpenAI-style models average ~4 characters per token for English text.
|
||||
|
||||
### How the system prompt is built
|
||||
|
||||
ZeroClaw assembles its own system prompt on every run. It includes:
|
||||
|
||||
* Tool list + short descriptions
|
||||
* Skills list (only metadata; instructions are loaded on demand with `read`)
|
||||
* Self-update instructions
|
||||
* Workspace + bootstrap files (`AGENTS.md`, `SOUL.md`, `TOOLS.md`, `IDENTITY.md`, `USER.md`, `HEARTBEAT.md`, `BOOTSTRAP.md` when new, plus `MEMORY.md` and/or `memory.md` when present). Large files are truncated by `agents.defaults.bootstrapMaxChars` (default: 20000). `memory/*.md` files are on-demand via memory tools and are not auto-injected.
|
||||
* Time (UTC + user timezone)
|
||||
* Reply tags + heartbeat behavior
|
||||
* Runtime metadata (host/OS/model/thinking)
|
||||
|
||||
### What counts in the context window
|
||||
|
||||
Everything the model receives counts toward the context limit:
|
||||
|
||||
* System prompt (all sections listed above)
|
||||
* Conversation history (user + assistant messages)
|
||||
* Tool calls and tool results
|
||||
* Attachments/transcripts (images, audio, files)
|
||||
* Compaction summaries and pruning artifacts
|
||||
* Provider wrappers or safety headers (not visible, but still counted)
|
||||
|
||||
### How to see current token usage
|
||||
|
||||
Use these in chat:
|
||||
|
||||
* `/status` → **emoji-rich status card** with the session model, context usage,
|
||||
last response input/output tokens, and **estimated cost** (API key only).
|
||||
* `/usage off|tokens|full` → appends a **per-response usage footer** to every reply.
|
||||
* Persists per session (stored as `responseUsage`).
|
||||
* OAuth auth **hides cost** (tokens only).
|
||||
* `/usage cost` → shows a local cost summary from ZeroClaw session logs.
|
||||
|
||||
Other surfaces:
|
||||
|
||||
* **TUI/Web TUI:** `/status` + `/usage` are supported.
|
||||
* **CLI:** `zeroclaw status --usage` and `zeroclaw channels list` show
|
||||
provider quota windows (not per-response costs).
|
||||
|
||||
### Cost estimation (when shown)
|
||||
|
||||
Costs are estimated from your model pricing config:
|
||||
|
||||
```
|
||||
models.providers.<provider>.models[].cost
|
||||
```
|
||||
|
||||
These are **USD per 1M tokens** for `input`, `output`, `cacheRead`, and
|
||||
`cacheWrite`. If pricing is missing, ZeroClaw shows tokens only. OAuth tokens
|
||||
never show dollar cost.
|
||||
|
||||
### Cache TTL and pruning impact
|
||||
|
||||
Provider prompt caching only applies within the cache TTL window. ZeroClaw can
|
||||
optionally run **cache-ttl pruning**: it prunes the session once the cache TTL
|
||||
has expired, then resets the cache window so subsequent requests can re-use the
|
||||
freshly cached context instead of re-caching the full history. This keeps cache
|
||||
write costs lower when a session goes idle past the TTL.
|
||||
|
||||
Configure it in Gateway configuration and see the behavior details in
|
||||
[Session pruning](/concepts/session-pruning).
|
||||
|
||||
Heartbeat can keep the cache **warm** across idle gaps. If your model cache TTL
|
||||
is `1h`, setting the heartbeat interval just under that (e.g., `55m`) can avoid
|
||||
re-caching the full prompt, reducing cache write costs.
|
||||
|
||||
For Anthropic API pricing, cache reads are significantly cheaper than input
|
||||
tokens, while cache writes are billed at a higher multiplier. See Anthropic's
|
||||
prompt caching pricing for the latest rates and TTL multipliers:
|
||||
[https://docs.anthropic.com/docs/build-with-claude/prompt-caching](https://docs.anthropic.com/docs/build-with-claude/prompt-caching)
|
||||
|
||||
#### Example: keep 1h cache warm with heartbeat
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
defaults:
|
||||
model:
|
||||
primary: "anthropic/claude-opus-4-6"
|
||||
models:
|
||||
"anthropic/claude-opus-4-6":
|
||||
params:
|
||||
cacheRetention: "long"
|
||||
heartbeat:
|
||||
every: "55m"
|
||||
```
|
||||
|
||||
### Tips for reducing token pressure
|
||||
|
||||
* Use `/compact` to summarize long sessions.
|
||||
* Trim large tool outputs in your workflows.
|
||||
* Keep skill descriptions short (skill list is injected into the prompt).
|
||||
* Prefer smaller models for verbose, exploratory work.
|
||||
|
||||
```toml
|
||||
api_key = "sk-..."
|
||||
default_provider = "openrouter"
|
||||
default_model = "anthropic/claude-sonnet-4-20250514"
|
||||
default_temperature = 0.7
|
||||
|
||||
[memory]
|
||||
backend = "sqlite" # "sqlite", "markdown", "none"
|
||||
auto_save = true
|
||||
|
||||
[autonomy]
|
||||
level = "supervised" # "readonly", "supervised", "full"
|
||||
workspace_only = true
|
||||
allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"]
|
||||
|
||||
[heartbeat]
|
||||
enabled = false
|
||||
interval_minutes = 30
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `onboard` | Initialize workspace and config |
|
||||
| `agent -m "..."` | Single message mode |
|
||||
| `agent` | Interactive chat mode |
|
||||
| `status -v` | Show full system status |
|
||||
| `tools list` | List all 6 tools |
|
||||
| `tools test <name> <json>` | Test a tool directly |
|
||||
| `gateway` | Start webhook/WebSocket server |
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
cargo build # Dev build
|
||||
cargo build --release # Release build (~3MB)
|
||||
cargo test # 502 tests
|
||||
cargo clippy # Lint (0 warnings)
|
||||
|
||||
# Run the SQLite vs Markdown benchmark
|
||||
cargo test --test memory_comparison -- --nocapture
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── main.rs # CLI (clap)
|
||||
├── lib.rs # Library exports
|
||||
├── agent/ # Agent loop + context injection
|
||||
├── channels/ # Channel trait + CLI
|
||||
├── config/ # TOML config schema
|
||||
├── cron/ # Scheduled tasks
|
||||
├── heartbeat/ # HEARTBEAT.md engine
|
||||
├── memory/ # Memory trait + SQLite + Markdown
|
||||
├── observability/ # Observer trait + Noop/Log/Multi
|
||||
├── providers/ # Provider trait + 22 providers
|
||||
├── runtime/ # RuntimeAdapter trait + Native
|
||||
├── security/ # Sandbox + allowlists + autonomy
|
||||
└── tools/ # Tool trait + shell/file/memory tools
|
||||
examples/
|
||||
├── custom_provider.rs
|
||||
├── custom_channel.rs
|
||||
├── custom_tool.rs
|
||||
└── custom_memory.rs
|
||||
tests/
|
||||
└── memory_comparison.rs # SQLite vs Markdown benchmark
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT — see [LICENSE](LICENSE)
|
||||
|
||||
## Contributing
|
||||
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
|
||||
- New `Provider` → `src/providers/`
|
||||
- New `Channel` → `src/channels/`
|
||||
- New `Observer` → `src/observability/`
|
||||
- New `Tool` → `src/tools/`
|
||||
- New `Memory` → `src/memory/`
|
||||
|
||||
---
|
||||
|
||||
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀
|
||||
63
SECURITY.md
Normal file
63
SECURITY.md
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| 0.1.x | :white_check_mark: |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
**Please do NOT open a public GitHub issue for security vulnerabilities.**
|
||||
|
||||
Instead, please report them responsibly:
|
||||
|
||||
1. **Email**: Send details to the maintainers via GitHub private vulnerability reporting
|
||||
2. **GitHub**: Use [GitHub Security Advisories](https://github.com/theonlyhennygod/zeroclaw/security/advisories/new)
|
||||
|
||||
### What to Include
|
||||
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Impact assessment
|
||||
- Suggested fix (if any)
|
||||
|
||||
### Response Timeline
|
||||
|
||||
- **Acknowledgment**: Within 48 hours
|
||||
- **Assessment**: Within 1 week
|
||||
- **Fix**: Within 2 weeks for critical issues
|
||||
|
||||
## Security Architecture
|
||||
|
||||
ZeroClaw implements defense-in-depth security:
|
||||
|
||||
### Autonomy Levels
|
||||
- **ReadOnly** — Agent can only read, no shell or write access
|
||||
- **Supervised** — Agent can act within allowlists (default)
|
||||
- **Full** — Agent has full access within workspace sandbox
|
||||
|
||||
### Sandboxing Layers
|
||||
1. **Workspace isolation** — All file operations confined to workspace directory
|
||||
2. **Path traversal blocking** — `..` sequences and absolute paths rejected
|
||||
3. **Command allowlisting** — Only explicitly approved commands can execute
|
||||
4. **Forbidden path list** — Critical system paths (`/etc`, `/root`, `~/.ssh`) always blocked
|
||||
5. **Rate limiting** — Max actions per hour and cost per day caps
|
||||
|
||||
### What We Protect Against
|
||||
- Path traversal attacks (`../../../etc/passwd`)
|
||||
- Command injection (`rm -rf /`, `curl | sh`)
|
||||
- Workspace escape via symlinks or absolute paths
|
||||
- Runaway cost from LLM API calls
|
||||
- Unauthorized shell command execution
|
||||
|
||||
## Security Testing
|
||||
|
||||
All security mechanisms are covered by automated tests (129 tests):
|
||||
|
||||
```bash
|
||||
cargo test -- security
|
||||
cargo test -- tools::shell
|
||||
cargo test -- tools::file_read
|
||||
cargo test -- tools::file_write
|
||||
```
|
||||
34
deny.toml
Normal file
34
deny.toml
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
# cargo-deny configuration
|
||||
# https://embarkstudios.github.io/cargo-deny/
|
||||
|
||||
[advisories]
|
||||
vulnerability = "deny"
|
||||
unmaintained = "warn"
|
||||
yanked = "warn"
|
||||
notice = "warn"
|
||||
|
||||
[licenses]
|
||||
unlicensed = "deny"
|
||||
allow = [
|
||||
"MIT",
|
||||
"Apache-2.0",
|
||||
"BSD-2-Clause",
|
||||
"BSD-3-Clause",
|
||||
"ISC",
|
||||
"Unicode-3.0",
|
||||
"Unicode-DFS-2016",
|
||||
"OpenSSL",
|
||||
"Zlib",
|
||||
"MPL-2.0",
|
||||
]
|
||||
copyleft = "deny"
|
||||
|
||||
[bans]
|
||||
multiple-versions = "warn"
|
||||
wildcards = "allow"
|
||||
|
||||
[sources]
|
||||
unknown-registry = "deny"
|
||||
unknown-git = "deny"
|
||||
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
|
||||
allow-git = []
|
||||
124
examples/custom_channel.rs
Normal file
124
examples/custom_channel.rs
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
//! Example: Implementing a custom Channel for ZeroClaw
|
||||
//!
|
||||
//! Channels let ZeroClaw communicate through any messaging platform.
|
||||
//! Implement the Channel trait, register it, and the agent works everywhere.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Mirrors src/channels/traits.rs
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChannelMessage {
|
||||
pub id: String,
|
||||
pub sender: String,
|
||||
pub content: String,
|
||||
pub channel: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Channel: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
async fn send(&self, message: &str, recipient: &str) -> Result<()>;
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()>;
|
||||
async fn health_check(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Example: Telegram channel via Bot API
|
||||
pub struct TelegramChannel {
|
||||
bot_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl TelegramChannel {
|
||||
pub fn new(bot_token: &str, allowed_users: Vec<String>) -> Self {
|
||||
Self {
|
||||
bot_token: bot_token.to_string(),
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_url(&self, method: &str) -> String {
|
||||
format!("https://api.telegram.org/bot{}/{method}", self.bot_token)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for TelegramChannel {
|
||||
fn name(&self) -> &str {
|
||||
"telegram"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, chat_id: &str) -> Result<()> {
|
||||
self.client
|
||||
.post(&self.api_url("sendMessage"))
|
||||
.json(&serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": message,
|
||||
"parse_mode": "Markdown",
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
|
||||
let mut offset: i64 = 0;
|
||||
|
||||
loop {
|
||||
let resp = self
|
||||
.client
|
||||
.get(&self.api_url("getUpdates"))
|
||||
.query(&[("offset", offset.to_string()), ("timeout", "30".into())])
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
if let Some(updates) = resp["result"].as_array() {
|
||||
for update in updates {
|
||||
if let Some(msg) = update.get("message") {
|
||||
let sender = msg["from"]["username"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
if !self.allowed_users.is_empty() && !self.allowed_users.contains(&sender) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: msg["message_id"].to_string(),
|
||||
sender,
|
||||
content: msg["text"].as_str().unwrap_or("").to_string(),
|
||||
channel: "telegram".into(),
|
||||
timestamp: msg["date"].as_u64().unwrap_or(0),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
offset = update["update_id"].as_i64().unwrap_or(offset) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get(&self.api_url("getMe"))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("This is an example — see CONTRIBUTING.md for integration steps.");
|
||||
println!("Add your channel config to ChannelsConfig in src/config/schema.rs");
|
||||
}
|
||||
160
examples/custom_memory.rs
Normal file
160
examples/custom_memory.rs
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
//! Example: Implementing a custom Memory backend for ZeroClaw
|
||||
//!
|
||||
//! This demonstrates how to create a Redis-backed memory backend.
|
||||
//! The Memory trait is async and pluggable — implement it for any storage.
|
||||
//!
|
||||
//! Run: cargo run --example custom_memory
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
// ── Re-define the trait types (in your app, import from zeroclaw::memory) ──
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum MemoryCategory {
|
||||
Core,
|
||||
Daily,
|
||||
Conversation,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryEntry {
|
||||
pub id: String,
|
||||
pub key: String,
|
||||
pub content: String,
|
||||
pub category: MemoryCategory,
|
||||
pub timestamp: String,
|
||||
pub score: Option<f64>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Memory: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
async fn count(&self) -> anyhow::Result<usize>;
|
||||
}
|
||||
|
||||
// ── Your custom implementation ─────────────────────────────────────
|
||||
|
||||
/// In-memory HashMap backend (great for testing or ephemeral sessions)
|
||||
pub struct InMemoryBackend {
|
||||
store: Mutex<HashMap<String, MemoryEntry>>,
|
||||
}
|
||||
|
||||
impl InMemoryBackend {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
store: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for InMemoryBackend {
|
||||
fn name(&self) -> &str {
|
||||
"in-memory"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = MemoryEntry {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
key: key.to_string(),
|
||||
content: content.to_string(),
|
||||
category,
|
||||
timestamp: chrono::Local::now().to_rfc3339(),
|
||||
score: None,
|
||||
};
|
||||
self.store
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("{e}"))?
|
||||
.insert(key.to_string(), entry);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let store = self.store.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
let query_lower = query.to_lowercase();
|
||||
|
||||
let mut results: Vec<MemoryEntry> = store
|
||||
.values()
|
||||
.filter(|e| e.content.to_lowercase().contains(&query_lower))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
results.truncate(limit);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let store = self.store.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
Ok(store.get(key).cloned())
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
let mut store = self.store.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
Ok(store.remove(key).is_some())
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let store = self.store.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
Ok(store.len())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Demo usage ─────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let brain = InMemoryBackend::new();
|
||||
|
||||
println!("🧠 ZeroClaw Memory Demo — InMemoryBackend\n");
|
||||
|
||||
// Store some memories
|
||||
brain
|
||||
.store("user_lang", "User prefers Rust", MemoryCategory::Core)
|
||||
.await?;
|
||||
brain
|
||||
.store("user_tz", "Timezone is EST", MemoryCategory::Core)
|
||||
.await?;
|
||||
brain
|
||||
.store(
|
||||
"today_note",
|
||||
"Completed memory system implementation",
|
||||
MemoryCategory::Daily,
|
||||
)
|
||||
.await?;
|
||||
|
||||
println!("Stored {} memories", brain.count().await?);
|
||||
|
||||
// Recall by keyword
|
||||
let results = brain.recall("Rust", 5).await?;
|
||||
println!("\nRecall 'Rust' → {} results:", results.len());
|
||||
for entry in &results {
|
||||
println!(" [{:?}] {}: {}", entry.category, entry.key, entry.content);
|
||||
}
|
||||
|
||||
// Get by key
|
||||
if let Some(entry) = brain.get("user_tz").await? {
|
||||
println!("\nGet 'user_tz' → {}", entry.content);
|
||||
}
|
||||
|
||||
// Forget
|
||||
let removed = brain.forget("user_tz").await?;
|
||||
println!("Forget 'user_tz' → removed: {removed}");
|
||||
println!("Remaining: {} memories", brain.count().await?);
|
||||
|
||||
println!("\n✅ Memory backend works! Implement the Memory trait for any storage.");
|
||||
Ok(())
|
||||
}
|
||||
65
examples/custom_provider.rs
Normal file
65
examples/custom_provider.rs
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
//! Example: Implementing a custom Provider for ZeroClaw
|
||||
//!
|
||||
//! This shows how to add a new LLM backend in ~30 lines of code.
|
||||
//! Copy this file, modify the API call, and register in `src/providers/mod.rs`.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
|
||||
// In a real implementation, you'd import from the crate:
|
||||
// use zeroclaw::providers::traits::Provider;
|
||||
|
||||
/// Minimal Provider trait (mirrors src/providers/traits.rs)
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
async fn chat(&self, message: &str, model: &str, temperature: f64) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Example: Ollama local provider
|
||||
pub struct OllamaProvider {
|
||||
base_url: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
base_url: base_url.unwrap_or("http://localhost:11434").to_string(),
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OllamaProvider {
|
||||
async fn chat(&self, message: &str, model: &str, temperature: f64) -> Result<String> {
|
||||
let url = format!("{}/api/generate", self.base_url);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"model": model,
|
||||
"prompt": message,
|
||||
"temperature": temperature,
|
||||
"stream": false,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
resp["response"]
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response field in Ollama reply"))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("This is an example — see CONTRIBUTING.md for integration steps.");
|
||||
println!("Register your provider in src/providers/mod.rs:");
|
||||
println!(" \"ollama\" => Ok(Box::new(ollama::OllamaProvider::new(None))),");
|
||||
}
|
||||
76
examples/custom_tool.rs
Normal file
76
examples/custom_tool.rs
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
//! Example: Implementing a custom Tool for ZeroClaw
|
||||
//!
|
||||
//! This shows how to add a new tool the agent can use.
|
||||
//! Tools are the agent's hands — they let it interact with the world.
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Mirrors src/tools/traits.rs
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ToolResult {
|
||||
pub success: bool,
|
||||
pub output: String,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn description(&self) -> &str;
|
||||
fn parameters_schema(&self) -> Value;
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult>;
|
||||
}
|
||||
|
||||
/// Example: A tool that fetches a URL and returns the status code
|
||||
pub struct HttpGetTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for HttpGetTool {
|
||||
fn name(&self) -> &str {
|
||||
"http_get"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fetch a URL and return the HTTP status code and content length"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string", "description": "URL to fetch" }
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let url = args["url"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' parameter"))?;
|
||||
|
||||
match reqwest::get(url).await {
|
||||
Ok(resp) => {
|
||||
let status = resp.status().as_u16();
|
||||
let len = resp.content_length().unwrap_or(0);
|
||||
Ok(ToolResult {
|
||||
success: status < 400,
|
||||
output: format!("HTTP {status} — {len} bytes"),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Request failed: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("This is an example — see CONTRIBUTING.md for integration steps.");
|
||||
println!("Register your tool in src/tools/mod.rs default_tools()");
|
||||
}
|
||||
182
src/agent/loop_.rs
Normal file
182
src/agent/loop_.rs
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::observability::{self, Observer, ObserverEvent};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools;
|
||||
use anyhow::Result;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Build context preamble by searching memory for relevant entries
|
||||
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
||||
if !entries.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &entries {
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
context.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
context
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run(
|
||||
config: Config,
|
||||
message: Option<String>,
|
||||
provider_override: Option<String>,
|
||||
model_override: Option<String>,
|
||||
temperature: f64,
|
||||
) -> Result<()> {
|
||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let _runtime = runtime::create_runtime(&config.runtime);
|
||||
let security = Arc::new(SecurityPolicy::from_config(
|
||||
&config.autonomy,
|
||||
&config.workspace_dir,
|
||||
));
|
||||
|
||||
// ── Memory (the brain) ────────────────────────────────────────
|
||||
let mem: Arc<dyn Memory> =
|
||||
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||
tracing::info!(backend = mem.name(), "Memory initialized");
|
||||
|
||||
// ── Tools (including memory tools) ────────────────────────────
|
||||
let _tools = tools::all_tools(security, mem.clone());
|
||||
|
||||
// ── Resolve provider ─────────────────────────────────────────
|
||||
let provider_name = provider_override
|
||||
.as_deref()
|
||||
.or(config.default_provider.as_deref())
|
||||
.unwrap_or("openrouter");
|
||||
|
||||
let model_name = model_override
|
||||
.as_deref()
|
||||
.or(config.default_model.as_deref())
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514");
|
||||
|
||||
let provider: Box<dyn Provider> =
|
||||
providers::create_provider(provider_name, config.api_key.as_deref())?;
|
||||
|
||||
observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name.to_string(),
|
||||
model: model_name.to_string(),
|
||||
});
|
||||
|
||||
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
||||
let skills = crate::skills::load_skills(&config.workspace_dir);
|
||||
let tool_descs: Vec<(&str, &str)> = vec![
|
||||
("shell", "Execute terminal commands"),
|
||||
("file_read", "Read file contents"),
|
||||
("file_write", "Write file contents"),
|
||||
("memory_store", "Save to memory"),
|
||||
("memory_recall", "Search memory"),
|
||||
("memory_forget", "Delete a memory entry"),
|
||||
];
|
||||
let system_prompt = crate::channels::build_system_prompt(
|
||||
&config.workspace_dir,
|
||||
model_name,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
);
|
||||
|
||||
// ── Execute ──────────────────────────────────────────────────
|
||||
let start = Instant::now();
|
||||
|
||||
if let Some(msg) = message {
|
||||
// Auto-save user message to memory
|
||||
if config.memory.auto_save {
|
||||
let _ = mem
|
||||
.store("user_msg", &msg, MemoryCategory::Conversation)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Inject memory context into user message
|
||||
let context = build_context(mem.as_ref(), &msg).await;
|
||||
let enriched = if context.is_empty() {
|
||||
msg.clone()
|
||||
} else {
|
||||
format!("{context}{msg}")
|
||||
};
|
||||
|
||||
let response = provider
|
||||
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
|
||||
.await?;
|
||||
println!("{response}");
|
||||
|
||||
// Auto-save assistant response to daily log
|
||||
if config.memory.auto_save {
|
||||
let summary = if response.len() > 100 {
|
||||
format!("{}...", &response[..100])
|
||||
} else {
|
||||
response.clone()
|
||||
};
|
||||
let _ = mem
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||
.await;
|
||||
}
|
||||
} else {
|
||||
println!("🦀 ZeroClaw Interactive Mode");
|
||||
println!("Type /quit to exit.\n");
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
|
||||
let cli = crate::channels::CliChannel::new();
|
||||
|
||||
// Spawn listener
|
||||
let listen_handle = tokio::spawn(async move {
|
||||
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
||||
});
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
// Auto-save conversation turns
|
||||
if config.memory.auto_save {
|
||||
let _ = mem
|
||||
.store("user_msg", &msg.content, MemoryCategory::Conversation)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Inject memory context into user message
|
||||
let context = build_context(mem.as_ref(), &msg.content).await;
|
||||
let enriched = if context.is_empty() {
|
||||
msg.content.clone()
|
||||
} else {
|
||||
format!("{context}{}", msg.content)
|
||||
};
|
||||
|
||||
let response = provider
|
||||
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
|
||||
.await?;
|
||||
println!("\n{response}\n");
|
||||
|
||||
if config.memory.auto_save {
|
||||
let summary = if response.len() > 100 {
|
||||
format!("{}...", &response[..100])
|
||||
} else {
|
||||
response.clone()
|
||||
};
|
||||
let _ = mem
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
listen_handle.abort();
|
||||
}
|
||||
|
||||
let duration = start.elapsed();
|
||||
observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used: None,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
3
src/agent/mod.rs
Normal file
3
src/agent/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod loop_;
|
||||
|
||||
pub use loop_::run;
|
||||
117
src/channels/cli.rs
Normal file
117
src/channels/cli.rs
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use tokio::io::{self, AsyncBufReadExt, BufReader};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// CLI channel — stdin/stdout, always available, zero deps
|
||||
pub struct CliChannel;
|
||||
|
||||
impl CliChannel {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for CliChannel {
|
||||
fn name(&self) -> &str {
|
||||
"cli"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, _recipient: &str) -> anyhow::Result<()> {
|
||||
println!("{message}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let stdin = io::stdin();
|
||||
let reader = BufReader::new(stdin);
|
||||
let mut lines = reader.lines();
|
||||
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
let line = line.trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if line == "/quit" || line == "/exit" {
|
||||
break;
|
||||
}
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: "user".to_string(),
|
||||
content: line,
|
||||
channel: "cli".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cli_channel_name() {
|
||||
assert_eq!(CliChannel::new().name(), "cli");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cli_channel_send_does_not_panic() {
|
||||
let ch = CliChannel::new();
|
||||
let result = ch.send("hello", "user").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cli_channel_send_empty_message() {
|
||||
let ch = CliChannel::new();
|
||||
let result = ch.send("", "").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cli_channel_health_check() {
|
||||
let ch = CliChannel::new();
|
||||
assert!(ch.health_check().await);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_message_struct() {
|
||||
let msg = ChannelMessage {
|
||||
id: "test-id".into(),
|
||||
sender: "user".into(),
|
||||
content: "hello".into(),
|
||||
channel: "cli".into(),
|
||||
timestamp: 1234567890,
|
||||
};
|
||||
assert_eq!(msg.id, "test-id");
|
||||
assert_eq!(msg.sender, "user");
|
||||
assert_eq!(msg.content, "hello");
|
||||
assert_eq!(msg.channel, "cli");
|
||||
assert_eq!(msg.timestamp, 1234567890);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_message_clone() {
|
||||
let msg = ChannelMessage {
|
||||
id: "id".into(),
|
||||
sender: "s".into(),
|
||||
content: "c".into(),
|
||||
channel: "ch".into(),
|
||||
timestamp: 0,
|
||||
};
|
||||
let cloned = msg.clone();
|
||||
assert_eq!(cloned.id, msg.id);
|
||||
assert_eq!(cloned.content, msg.content);
|
||||
}
|
||||
}
|
||||
271
src/channels/discord.rs
Normal file
271
src/channels/discord.rs
Normal file
|
|
@ -0,0 +1,271 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde_json::json;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Discord channel — connects via Gateway WebSocket for real-time messages
|
||||
pub struct DiscordChannel {
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
pub fn new(bot_token: String, guild_id: Option<String>) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
||||
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
|
||||
let part = token.split('.').next()?;
|
||||
base64_decode(part)
|
||||
}
|
||||
}
|
||||
|
||||
const BASE64_ALPHABET: &[u8] =
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
|
||||
/// Minimal base64 decode (no extra dep) — only needs to decode the user ID portion
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
fn base64_decode(input: &str) -> Option<String> {
|
||||
let padded = match input.len() % 4 {
|
||||
2 => format!("{input}=="),
|
||||
3 => format!("{input}="),
|
||||
_ => input.to_string(),
|
||||
};
|
||||
|
||||
let mut bytes = Vec::new();
|
||||
let chars: Vec<u8> = padded.bytes().collect();
|
||||
|
||||
for chunk in chars.chunks(4) {
|
||||
if chunk.len() < 4 {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut v = [0usize; 4];
|
||||
for (i, &b) in chunk.iter().enumerate() {
|
||||
if b == b'=' {
|
||||
v[i] = 0;
|
||||
} else {
|
||||
v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
|
||||
}
|
||||
}
|
||||
|
||||
bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
|
||||
if chunk[2] != b'=' {
|
||||
bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
|
||||
}
|
||||
if chunk[3] != b'=' {
|
||||
bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
|
||||
}
|
||||
}
|
||||
|
||||
String::from_utf8(bytes).ok()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for DiscordChannel {
|
||||
fn name(&self) -> &str {
|
||||
"discord"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, channel_id: &str) -> anyhow::Result<()> {
|
||||
let url = format!("https://discord.com/api/v10/channels/{channel_id}/messages");
|
||||
let body = json!({ "content": message });
|
||||
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
|
||||
|
||||
// Get Gateway URL
|
||||
let gw_resp: serde_json::Value = self
|
||||
.client
|
||||
.get("https://discord.com/api/v10/gateway/bot")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let gw_url = gw_resp
|
||||
.get("url")
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("wss://gateway.discord.gg");
|
||||
|
||||
let ws_url = format!("{gw_url}/?v=10&encoding=json");
|
||||
tracing::info!("Discord: connecting to gateway...");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
|
||||
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
|
||||
let heartbeat_interval = hello_data
|
||||
.get("d")
|
||||
.and_then(|d| d.get("heartbeat_interval"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(41250);
|
||||
|
||||
// Send Identify (opcode 2)
|
||||
let identify = json!({
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.bot_token,
|
||||
"intents": 33281, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES
|
||||
"properties": {
|
||||
"os": "linux",
|
||||
"browser": "zeroclaw",
|
||||
"device": "zeroclaw"
|
||||
}
|
||||
}
|
||||
});
|
||||
write.send(Message::Text(identify.to_string())).await?;
|
||||
|
||||
tracing::info!("Discord: connected and identified");
|
||||
|
||||
// Spawn heartbeat task
|
||||
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
|
||||
let hb_interval = heartbeat_interval;
|
||||
tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(std::time::Duration::from_millis(hb_interval));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if hb_tx.send(()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let guild_filter = self.guild_id.clone();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = hb_rx.recv() => {
|
||||
let hb = json!({"op": 1, "d": null});
|
||||
if write.send(Message::Text(hb.to_string())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = read.next() => {
|
||||
let msg = match msg {
|
||||
Some(Ok(Message::Text(t))) => t,
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_str(&msg) {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE")
|
||||
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
|
||||
if event_type != "MESSAGE_CREATE" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(d) = event.get("d") else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Skip messages from the bot itself
|
||||
let author_id = d.get("author").and_then(|a| a.get("id")).and_then(|i| i.as_str()).unwrap_or("");
|
||||
if author_id == bot_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip bot messages
|
||||
if d.get("author").and_then(|a| a.get("bot")).and_then(serde_json::Value::as_bool).unwrap_or(false) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Guild filter
|
||||
if let Some(ref gid) = guild_filter {
|
||||
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
|
||||
if msg_guild != gid {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: channel_id,
|
||||
content: content.to_string(),
|
||||
channel: "discord".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get("https://discord.com/api/v10/users/@me")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn discord_channel_name() {
|
||||
let ch = DiscordChannel::new("fake".into(), None);
|
||||
assert_eq!(ch.name(), "discord");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn base64_decode_bot_id() {
|
||||
// "MTIzNDU2" decodes to "123456"
|
||||
let decoded = base64_decode("MTIzNDU2");
|
||||
assert_eq!(decoded, Some("123456".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bot_user_id_extraction() {
|
||||
// Token format: base64(user_id).timestamp.hmac
|
||||
let token = "MTIzNDU2.fake.hmac";
|
||||
let id = DiscordChannel::bot_user_id_from_token(token);
|
||||
assert_eq!(id, Some("123456".to_string()));
|
||||
}
|
||||
}
|
||||
265
src/channels/imessage.rs
Normal file
265
src/channels/imessage.rs
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
use crate::channels::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use directories::UserDirs;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// iMessage channel using macOS `AppleScript` bridge.
|
||||
/// Polls the Messages database for new messages and sends replies via `osascript`.
|
||||
#[derive(Clone)]
|
||||
pub struct IMessageChannel {
|
||||
allowed_contacts: Vec<String>,
|
||||
poll_interval_secs: u64,
|
||||
}
|
||||
|
||||
impl IMessageChannel {
|
||||
pub fn new(allowed_contacts: Vec<String>) -> Self {
|
||||
Self {
|
||||
allowed_contacts,
|
||||
poll_interval_secs: 3,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_contact_allowed(&self, sender: &str) -> bool {
|
||||
if self.allowed_contacts.iter().any(|u| u == "*") {
|
||||
return true;
|
||||
}
|
||||
self.allowed_contacts.iter().any(|u| {
|
||||
u.eq_ignore_ascii_case(sender)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for IMessageChannel {
|
||||
fn name(&self) -> &str {
|
||||
"imessage"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> {
|
||||
let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\"");
|
||||
let script = format!(
|
||||
r#"tell application "Messages"
|
||||
set targetService to 1st account whose service type = iMessage
|
||||
set targetBuddy to participant "{target}" of targetService
|
||||
send "{escaped_msg}" to targetBuddy
|
||||
end tell"#
|
||||
);
|
||||
|
||||
let output = tokio::process::Command::new("osascript")
|
||||
.arg("-e")
|
||||
.arg(&script)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("iMessage send failed: {stderr}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
tracing::info!("iMessage channel listening (AppleScript bridge)...");
|
||||
|
||||
// Query the Messages SQLite database for new messages
|
||||
// The database is at ~/Library/Messages/chat.db
|
||||
let db_path = UserDirs::new()
|
||||
.map(|u| u.home_dir().join("Library/Messages/chat.db"))
|
||||
.ok_or_else(|| anyhow::anyhow!("Cannot find home directory"))?;
|
||||
|
||||
if !db_path.exists() {
|
||||
anyhow::bail!(
|
||||
"Messages database not found at {}. Ensure Messages.app is set up and Full Disk Access is granted.",
|
||||
db_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
// Track the last ROWID we've seen
|
||||
let mut last_rowid = get_max_rowid(&db_path).await.unwrap_or(0);
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(self.poll_interval_secs)).await;
|
||||
|
||||
let new_messages = fetch_new_messages(&db_path, last_rowid).await;
|
||||
|
||||
match new_messages {
|
||||
Ok(messages) => {
|
||||
for (rowid, sender, text) in messages {
|
||||
if rowid > last_rowid {
|
||||
last_rowid = rowid;
|
||||
}
|
||||
|
||||
if !self.is_contact_allowed(&sender) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: rowid.to_string(),
|
||||
sender: sender.clone(),
|
||||
content: text,
|
||||
channel: "imessage".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("iMessage poll error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
if !cfg!(target_os = "macos") {
|
||||
return false;
|
||||
}
|
||||
|
||||
let db_path = UserDirs::new()
|
||||
.map(|u| u.home_dir().join("Library/Messages/chat.db"))
|
||||
.unwrap_or_default();
|
||||
|
||||
db_path.exists()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current max ROWID from the messages table
|
||||
async fn get_max_rowid(db_path: &std::path::Path) -> anyhow::Result<i64> {
|
||||
let output = tokio::process::Command::new("sqlite3")
|
||||
.arg(db_path)
|
||||
.arg("SELECT MAX(ROWID) FROM message WHERE is_from_me = 0;")
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let rowid = stdout.trim().parse::<i64>().unwrap_or(0);
|
||||
Ok(rowid)
|
||||
}
|
||||
|
||||
/// Fetch messages newer than `since_rowid`
|
||||
async fn fetch_new_messages(
|
||||
db_path: &std::path::Path,
|
||||
since_rowid: i64,
|
||||
) -> anyhow::Result<Vec<(i64, String, String)>> {
|
||||
let query = format!(
|
||||
"SELECT m.ROWID, h.id, m.text \
|
||||
FROM message m \
|
||||
JOIN handle h ON m.handle_id = h.ROWID \
|
||||
WHERE m.ROWID > {since_rowid} \
|
||||
AND m.is_from_me = 0 \
|
||||
AND m.text IS NOT NULL \
|
||||
ORDER BY m.ROWID ASC \
|
||||
LIMIT 20;"
|
||||
);
|
||||
|
||||
let output = tokio::process::Command::new("sqlite3")
|
||||
.arg("-separator")
|
||||
.arg("|")
|
||||
.arg(db_path)
|
||||
.arg(&query)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("sqlite3 query failed: {stderr}");
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut results = Vec::new();
|
||||
|
||||
for line in stdout.lines() {
|
||||
let parts: Vec<&str> = line.splitn(3, '|').collect();
|
||||
if parts.len() == 3 {
|
||||
if let Ok(rowid) = parts[0].parse::<i64>() {
|
||||
results.push((rowid, parts[1].to_string(), parts[2].to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn creates_with_contacts() {
|
||||
let ch = IMessageChannel::new(vec!["+1234567890".into()]);
|
||||
assert_eq!(ch.allowed_contacts.len(), 1);
|
||||
assert_eq!(ch.poll_interval_secs, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_contacts() {
|
||||
let ch = IMessageChannel::new(vec![]);
|
||||
assert!(ch.allowed_contacts.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_anyone() {
|
||||
let ch = IMessageChannel::new(vec!["*".into()]);
|
||||
assert!(ch.is_contact_allowed("+1234567890"));
|
||||
assert!(ch.is_contact_allowed("random@icloud.com"));
|
||||
assert!(ch.is_contact_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_contact_allowed() {
|
||||
let ch = IMessageChannel::new(vec!["+1234567890".into(), "user@icloud.com".into()]);
|
||||
assert!(ch.is_contact_allowed("+1234567890"));
|
||||
assert!(ch.is_contact_allowed("user@icloud.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_contact_denied() {
|
||||
let ch = IMessageChannel::new(vec!["+1234567890".into()]);
|
||||
assert!(!ch.is_contact_allowed("+9999999999"));
|
||||
assert!(!ch.is_contact_allowed("hacker@evil.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn contact_case_insensitive() {
|
||||
let ch = IMessageChannel::new(vec!["User@iCloud.com".into()]);
|
||||
assert!(ch.is_contact_allowed("user@icloud.com"));
|
||||
assert!(ch.is_contact_allowed("USER@ICLOUD.COM"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_all() {
|
||||
let ch = IMessageChannel::new(vec![]);
|
||||
assert!(!ch.is_contact_allowed("+1234567890"));
|
||||
assert!(!ch.is_contact_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_returns_imessage() {
|
||||
let ch = IMessageChannel::new(vec![]);
|
||||
assert_eq!(ch.name(), "imessage");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_among_others_still_allows_all() {
|
||||
let ch = IMessageChannel::new(vec!["+111".into(), "*".into(), "+222".into()]);
|
||||
assert!(ch.is_contact_allowed("totally-unknown"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn contact_with_spaces_exact_match() {
|
||||
let ch = IMessageChannel::new(vec![" spaced ".into()]);
|
||||
assert!(ch.is_contact_allowed(" spaced "));
|
||||
assert!(!ch.is_contact_allowed("spaced"));
|
||||
}
|
||||
}
|
||||
467
src/channels/matrix.rs
Normal file
467
src/channels/matrix.rs
Normal file
|
|
@ -0,0 +1,467 @@
|
|||
use crate::channels::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Matrix channel using the Client-Server API (no SDK needed).
|
||||
/// Connects to any Matrix homeserver (Element, Synapse, etc.).
|
||||
#[derive(Clone)]
|
||||
pub struct MatrixChannel {
|
||||
homeserver: String,
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SyncResponse {
|
||||
next_batch: String,
|
||||
#[serde(default)]
|
||||
rooms: Rooms,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct Rooms {
|
||||
#[serde(default)]
|
||||
join: std::collections::HashMap<String, JoinedRoom>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JoinedRoom {
|
||||
#[serde(default)]
|
||||
timeline: Timeline,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct Timeline {
|
||||
#[serde(default)]
|
||||
events: Vec<TimelineEvent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TimelineEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
sender: String,
|
||||
#[serde(default)]
|
||||
content: EventContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct EventContent {
|
||||
#[serde(default)]
|
||||
body: Option<String>,
|
||||
#[serde(default)]
|
||||
msgtype: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WhoAmIResponse {
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
impl MatrixChannel {
|
||||
pub fn new(
|
||||
homeserver: String,
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
let homeserver = if homeserver.ends_with('/') {
|
||||
homeserver[..homeserver.len() - 1].to_string()
|
||||
} else {
|
||||
homeserver
|
||||
};
|
||||
Self {
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, sender: &str) -> bool {
|
||||
if self.allowed_users.iter().any(|u| u == "*") {
|
||||
return true;
|
||||
}
|
||||
self.allowed_users
|
||||
.iter()
|
||||
.any(|u| u.eq_ignore_ascii_case(sender))
|
||||
}
|
||||
|
||||
async fn get_my_user_id(&self) -> anyhow::Result<String> {
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/account/whoami",
|
||||
self.homeserver
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await?;
|
||||
anyhow::bail!("Matrix whoami failed: {err}");
|
||||
}
|
||||
|
||||
let who: WhoAmIResponse = resp.json().await?;
|
||||
Ok(who.user_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for MatrixChannel {
|
||||
fn name(&self) -> &str {
|
||||
"matrix"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, _target: &str) -> anyhow::Result<()> {
|
||||
let txn_id = format!("zc_{}", chrono::Utc::now().timestamp_millis());
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}",
|
||||
self.homeserver, self.room_id, txn_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "m.text",
|
||||
"body": message
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.put(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await?;
|
||||
anyhow::bail!("Matrix send failed: {err}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
tracing::info!("Matrix channel listening on room {}...", self.room_id);
|
||||
|
||||
let my_user_id = self.get_my_user_id().await?;
|
||||
|
||||
// Initial sync to get the since token
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/sync?timeout=30000&filter={{\"room\":{{\"timeline\":{{\"limit\":1}}}}}}",
|
||||
self.homeserver
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await?;
|
||||
anyhow::bail!("Matrix initial sync failed: {err}");
|
||||
}
|
||||
|
||||
let sync: SyncResponse = resp.json().await?;
|
||||
let mut since = sync.next_batch;
|
||||
|
||||
// Long-poll loop
|
||||
loop {
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/sync?since={}&timeout=30000",
|
||||
self.homeserver, since
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let resp = match resp {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Matrix sync error: {e}, retrying...");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let sync: SyncResponse = resp.json().await?;
|
||||
since = sync.next_batch;
|
||||
|
||||
// Process events from our room
|
||||
if let Some(room) = sync.rooms.join.get(&self.room_id) {
|
||||
for event in &room.timeline.events {
|
||||
// Skip our own messages
|
||||
if event.sender == my_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only process text messages
|
||||
if event.event_type != "m.room.message" {
|
||||
continue;
|
||||
}
|
||||
|
||||
if event.content.msgtype.as_deref() != Some("m.text") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(ref body) = event.content.body else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !self.is_user_allowed(&event.sender) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
|
||||
sender: event.sender.clone(),
|
||||
content: body.clone(),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/account/whoami",
|
||||
self.homeserver
|
||||
);
|
||||
let Ok(resp) = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.send()
|
||||
.await
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
|
||||
resp.status().is_success()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_channel() -> MatrixChannel {
|
||||
MatrixChannel::new(
|
||||
"https://matrix.org".to_string(),
|
||||
"syt_test_token".to_string(),
|
||||
"!room:matrix.org".to_string(),
|
||||
vec!["@user:matrix.org".to_string()],
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_correct_fields() {
|
||||
let ch = make_channel();
|
||||
assert_eq!(ch.homeserver, "https://matrix.org");
|
||||
assert_eq!(ch.access_token, "syt_test_token");
|
||||
assert_eq!(ch.room_id, "!room:matrix.org");
|
||||
assert_eq!(ch.allowed_users.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strips_trailing_slash() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://matrix.org/".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.homeserver, "https://matrix.org");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_trailing_slash_unchanged() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://matrix.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.homeserver, "https://matrix.org");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_trailing_slashes_strips_one() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://matrix.org//".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(ch.homeserver, "https://matrix.org/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_anyone() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["*".to_string()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("@anyone:matrix.org"));
|
||||
assert!(ch.is_user_allowed("@hacker:evil.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_user_allowed() {
|
||||
let ch = make_channel();
|
||||
assert!(ch.is_user_allowed("@user:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_user_denied() {
|
||||
let ch = make_channel();
|
||||
assert!(!ch.is_user_allowed("@stranger:matrix.org"));
|
||||
assert!(!ch.is_user_allowed("@evil:hacker.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_case_insensitive() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@User:Matrix.org".to_string()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("@user:matrix.org"));
|
||||
assert!(ch.is_user_allowed("@USER:MATRIX.ORG"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_all() {
|
||||
let ch = MatrixChannel::new(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(!ch.is_user_allowed("@anyone:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_returns_matrix() {
|
||||
let ch = make_channel();
|
||||
assert_eq!(ch.name(), "matrix");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_response_deserializes_empty() {
|
||||
let json = r#"{"next_batch":"s123","rooms":{"join":{}}}"#;
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.next_batch, "s123");
|
||||
assert!(resp.rooms.join.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_response_deserializes_with_events() {
|
||||
let json = r#"{
|
||||
"next_batch": "s456",
|
||||
"rooms": {
|
||||
"join": {
|
||||
"!room:matrix.org": {
|
||||
"timeline": {
|
||||
"events": [
|
||||
{
|
||||
"type": "m.room.message",
|
||||
"sender": "@user:matrix.org",
|
||||
"content": {
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello!"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.next_batch, "s456");
|
||||
let room = resp.rooms.join.get("!room:matrix.org").unwrap();
|
||||
assert_eq!(room.timeline.events.len(), 1);
|
||||
assert_eq!(room.timeline.events[0].sender, "@user:matrix.org");
|
||||
assert_eq!(room.timeline.events[0].content.body.as_deref(), Some("Hello!"));
|
||||
assert_eq!(room.timeline.events[0].content.msgtype.as_deref(), Some("m.text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_response_ignores_non_text_events() {
|
||||
let json = r#"{
|
||||
"next_batch": "s789",
|
||||
"rooms": {
|
||||
"join": {
|
||||
"!room:m": {
|
||||
"timeline": {
|
||||
"events": [
|
||||
{
|
||||
"type": "m.room.member",
|
||||
"sender": "@user:m",
|
||||
"content": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
let room = resp.rooms.join.get("!room:m").unwrap();
|
||||
assert_eq!(room.timeline.events[0].event_type, "m.room.member");
|
||||
assert!(room.timeline.events[0].content.body.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn whoami_response_deserializes() {
|
||||
let json = r#"{"user_id":"@bot:matrix.org"}"#;
|
||||
let resp: WhoAmIResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.user_id, "@bot:matrix.org");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_content_defaults() {
|
||||
let json = r#"{"type":"m.room.message","sender":"@u:m","content":{}}"#;
|
||||
let event: TimelineEvent = serde_json::from_str(json).unwrap();
|
||||
assert!(event.content.body.is_none());
|
||||
assert!(event.content.msgtype.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_response_missing_rooms_defaults() {
|
||||
let json = r#"{"next_batch":"s0"}"#;
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.rooms.join.is_empty());
|
||||
}
|
||||
}
|
||||
550
src/channels/mod.rs
Normal file
550
src/channels/mod.rs
Normal file
|
|
@ -0,0 +1,550 @@
|
|||
pub mod cli;
|
||||
pub mod discord;
|
||||
pub mod imessage;
|
||||
pub mod matrix;
|
||||
pub mod slack;
|
||||
pub mod telegram;
|
||||
pub mod traits;
|
||||
|
||||
pub use cli::CliChannel;
|
||||
pub use discord::DiscordChannel;
|
||||
pub use imessage::IMessageChannel;
|
||||
pub use matrix::MatrixChannel;
|
||||
pub use slack::SlackChannel;
|
||||
pub use telegram::TelegramChannel;
|
||||
pub use traits::Channel;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::providers::{self, Provider};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
|
||||
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||||
|
||||
/// Load workspace identity files and build a system prompt.
|
||||
///
|
||||
/// Follows the `OpenClaw` framework structure:
|
||||
/// 1. Tooling — tool list + descriptions
|
||||
/// 2. Safety — guardrail reminder
|
||||
/// 3. Skills — compact list with paths (loaded on-demand)
|
||||
/// 4. Workspace — working directory
|
||||
/// 5. Bootstrap files — AGENTS, SOUL, TOOLS, IDENTITY, USER, HEARTBEAT, BOOTSTRAP, MEMORY
|
||||
/// 6. Date & Time — timezone for cache stability
|
||||
/// 7. Runtime — host, OS, model
|
||||
///
|
||||
/// Daily memory files (`memory/*.md`) are NOT injected — they are accessed
|
||||
/// on-demand via `memory_recall` / `memory_search` tools.
|
||||
pub fn build_system_prompt(
|
||||
workspace_dir: &std::path::Path,
|
||||
model_name: &str,
|
||||
tools: &[(&str, &str)],
|
||||
skills: &[crate::skills::Skill],
|
||||
) -> String {
|
||||
use std::fmt::Write;
|
||||
let mut prompt = String::with_capacity(8192);
|
||||
|
||||
// ── 1. Tooling ──────────────────────────────────────────────
|
||||
if !tools.is_empty() {
|
||||
prompt.push_str("## Tools\n\n");
|
||||
prompt.push_str("You have access to the following tools:\n\n");
|
||||
for (name, desc) in tools {
|
||||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
// ── 2. Safety ───────────────────────────────────────────────
|
||||
prompt.push_str("## Safety\n\n");
|
||||
prompt.push_str(
|
||||
"- Do not exfiltrate private data.\n\
|
||||
- Do not run destructive commands without asking.\n\
|
||||
- Do not bypass oversight or approval mechanisms.\n\
|
||||
- Prefer `trash` over `rm` (recoverable beats gone forever).\n\
|
||||
- When in doubt, ask before acting externally.\n\n",
|
||||
);
|
||||
|
||||
// ── 3. Skills (compact list — load on-demand) ───────────────
|
||||
if !skills.is_empty() {
|
||||
prompt.push_str("## Available Skills\n\n");
|
||||
prompt.push_str(
|
||||
"Skills are loaded on demand. Use `read` on the skill path to get full instructions.\n\n",
|
||||
);
|
||||
prompt.push_str("<available_skills>\n");
|
||||
for skill in skills {
|
||||
let _ = writeln!(prompt, " <skill>");
|
||||
let _ = writeln!(prompt, " <name>{}</name>", skill.name);
|
||||
let _ = writeln!(prompt, " <description>{}</description>", skill.description);
|
||||
let location = workspace_dir.join("skills").join(&skill.name).join("SKILL.md");
|
||||
let _ = writeln!(prompt, " <location>{}</location>", location.display());
|
||||
let _ = writeln!(prompt, " </skill>");
|
||||
}
|
||||
prompt.push_str("</available_skills>\n\n");
|
||||
}
|
||||
|
||||
// ── 4. Workspace ────────────────────────────────────────────
|
||||
let _ = writeln!(prompt, "## Workspace\n\nWorking directory: `{}`\n", workspace_dir.display());
|
||||
|
||||
// ── 5. Bootstrap files (injected into context) ──────────────
|
||||
prompt.push_str("## Project Context\n\n");
|
||||
prompt.push_str("The following workspace files define your identity, behavior, and context.\n\n");
|
||||
|
||||
let bootstrap_files = [
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"TOOLS.md",
|
||||
"IDENTITY.md",
|
||||
"USER.md",
|
||||
"HEARTBEAT.md",
|
||||
];
|
||||
|
||||
for filename in &bootstrap_files {
|
||||
inject_workspace_file(&mut prompt, workspace_dir, filename);
|
||||
}
|
||||
|
||||
// BOOTSTRAP.md — only if it exists (first-run ritual)
|
||||
let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
|
||||
if bootstrap_path.exists() {
|
||||
inject_workspace_file(&mut prompt, workspace_dir, "BOOTSTRAP.md");
|
||||
}
|
||||
|
||||
// MEMORY.md — curated long-term memory (main session only)
|
||||
inject_workspace_file(&mut prompt, workspace_dir, "MEMORY.md");
|
||||
|
||||
// ── 6. Date & Time ──────────────────────────────────────────
|
||||
let now = chrono::Local::now();
|
||||
let tz = now.format("%Z").to_string();
|
||||
let _ = writeln!(prompt, "## Current Date & Time\n\nTimezone: {tz}\n");
|
||||
|
||||
// ── 7. Runtime ──────────────────────────────────────────────
|
||||
let host = hostname::get()
|
||||
.map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
|
||||
let _ = writeln!(
|
||||
prompt,
|
||||
"## Runtime\n\nHost: {host} | OS: {} | Model: {model_name}\n",
|
||||
std::env::consts::OS,
|
||||
);
|
||||
|
||||
if prompt.is_empty() {
|
||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
|
||||
} else {
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
|
||||
fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) {
|
||||
use std::fmt::Write;
|
||||
|
||||
let path = workspace_dir.join(filename);
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(content) => {
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
let _ = writeln!(prompt, "### {filename}\n");
|
||||
if trimmed.len() > BOOTSTRAP_MAX_CHARS {
|
||||
prompt.push_str(&trimmed[..BOOTSTRAP_MAX_CHARS]);
|
||||
let _ = writeln!(
|
||||
prompt,
|
||||
"\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n"
|
||||
);
|
||||
} else {
|
||||
prompt.push_str(trimmed);
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Missing-file marker (matches OpenClaw behavior)
|
||||
let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Result<()> {
|
||||
match command {
|
||||
super::ChannelCommands::Start => {
|
||||
// Handled in main.rs (needs async), this is unreachable
|
||||
unreachable!("Start is handled in main.rs")
|
||||
}
|
||||
super::ChannelCommands::List => {
|
||||
println!("Channels:");
|
||||
println!(" ✅ CLI (always available)");
|
||||
for (name, configured) in [
|
||||
("Telegram", config.channels_config.telegram.is_some()),
|
||||
("Discord", config.channels_config.discord.is_some()),
|
||||
("Slack", config.channels_config.slack.is_some()),
|
||||
("Webhook", config.channels_config.webhook.is_some()),
|
||||
("iMessage", config.channels_config.imessage.is_some()),
|
||||
("Matrix", config.channels_config.matrix.is_some()),
|
||||
] {
|
||||
println!(
|
||||
" {} {name}",
|
||||
if configured { "✅" } else { "❌" }
|
||||
);
|
||||
}
|
||||
println!("\nTo start channels: zeroclaw channel start");
|
||||
println!("To configure: zeroclaw onboard");
|
||||
Ok(())
|
||||
}
|
||||
super::ChannelCommands::Add {
|
||||
channel_type,
|
||||
config: _,
|
||||
} => {
|
||||
anyhow::bail!("Channel type '{channel_type}' — use `zeroclaw onboard` to configure channels");
|
||||
}
|
||||
super::ChannelCommands::Remove { name } => {
|
||||
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start all configured channels and route messages to the agent
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
)?);
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let temperature = config.default_temperature;
|
||||
let mem: Arc<dyn Memory> =
|
||||
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let skills = crate::skills::load_skills(&workspace);
|
||||
|
||||
// Collect tool descriptions for the prompt
|
||||
let tool_descs: Vec<(&str, &str)> = vec![
|
||||
("shell", "Execute terminal commands"),
|
||||
("file_read", "Read file contents"),
|
||||
("file_write", "Write file contents"),
|
||||
("memory_store", "Save to memory"),
|
||||
("memory_recall", "Search memory"),
|
||||
("memory_forget", "Delete a memory entry"),
|
||||
];
|
||||
|
||||
let system_prompt = build_system_prompt(&workspace, &model, &tool_descs, &skills);
|
||||
|
||||
if !skills.is_empty() {
|
||||
println!(" 🧩 Skills: {}", skills.iter().map(|s| s.name.as_str()).collect::<Vec<_>>().join(", "));
|
||||
}
|
||||
|
||||
// Collect active channels
|
||||
let mut channels: Vec<Arc<dyn Channel>> = Vec::new();
|
||||
|
||||
if let Some(ref tg) = config.channels_config.telegram {
|
||||
channels.push(Arc::new(TelegramChannel::new(
|
||||
tg.bot_token.clone(),
|
||||
tg.allowed_users.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(ref dc) = config.channels_config.discord {
|
||||
channels.push(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(ref sl) = config.channels_config.slack {
|
||||
channels.push(Arc::new(SlackChannel::new(
|
||||
sl.bot_token.clone(),
|
||||
sl.channel_id.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(ref im) = config.channels_config.imessage {
|
||||
channels.push(Arc::new(IMessageChannel::new(
|
||||
im.allowed_contacts.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(ref mx) = config.channels_config.matrix {
|
||||
channels.push(Arc::new(MatrixChannel::new(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if channels.is_empty() {
|
||||
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("🦀 ZeroClaw Channel Server");
|
||||
println!(" 🤖 Model: {model}");
|
||||
println!(" 🧠 Memory: {} (auto-save: {})", config.memory.backend, if config.memory.auto_save { "on" } else { "off" });
|
||||
println!(" 📡 Channels: {}", channels.iter().map(|c| c.name()).collect::<Vec<_>>().join(", "));
|
||||
println!();
|
||||
println!(" Listening for messages... (Ctrl+C to stop)");
|
||||
println!();
|
||||
|
||||
// Single message bus — all channels send messages here
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||||
|
||||
// Spawn a listener for each channel
|
||||
let mut handles = Vec::new();
|
||||
for ch in &channels {
|
||||
let ch = ch.clone();
|
||||
let tx = tx.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
if let Err(e) = ch.listen(tx).await {
|
||||
tracing::error!("Channel {} error: {e}", ch.name());
|
||||
}
|
||||
}));
|
||||
}
|
||||
drop(tx); // Drop our copy so rx closes when all channels stop
|
||||
|
||||
// Process incoming messages — call the LLM and reply
|
||||
while let Some(msg) = rx.recv().await {
|
||||
println!(
|
||||
" 💬 [{}] from {}: {}",
|
||||
msg.channel,
|
||||
msg.sender,
|
||||
if msg.content.len() > 80 {
|
||||
format!("{}...", &msg.content[..80])
|
||||
} else {
|
||||
msg.content.clone()
|
||||
}
|
||||
);
|
||||
|
||||
// Auto-save to memory
|
||||
if config.memory.auto_save {
|
||||
let _ = mem
|
||||
.store(
|
||||
&format!("{}_{}", msg.channel, msg.sender),
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Call the LLM with system prompt (identity + soul + tools)
|
||||
match provider.chat_with_system(Some(&system_prompt), &msg.content, &model, temperature).await {
|
||||
Ok(response) => {
|
||||
println!(
|
||||
" 🤖 Reply: {}",
|
||||
if response.len() > 80 {
|
||||
format!("{}...", &response[..80])
|
||||
} else {
|
||||
response.clone()
|
||||
}
|
||||
);
|
||||
// Find the channel that sent this message and reply
|
||||
for ch in &channels {
|
||||
if ch.name() == msg.channel {
|
||||
if let Err(e) = ch.send(&response, &msg.sender).await {
|
||||
eprintln!(" ❌ Failed to reply on {}: {e}", ch.name());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(" ❌ LLM error: {e}");
|
||||
for ch in &channels {
|
||||
if ch.name() == msg.channel {
|
||||
let _ = ch
|
||||
.send(&format!("⚠️ Error: {e}"), &msg.sender)
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all channel tasks
|
||||
for h in handles {
|
||||
let _ = h.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn make_workspace() -> TempDir {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
// Create minimal workspace files
|
||||
std::fs::write(tmp.path().join("SOUL.md"), "# Soul\nBe helpful.").unwrap();
|
||||
std::fs::write(tmp.path().join("IDENTITY.md"), "# Identity\nName: ZeroClaw").unwrap();
|
||||
std::fs::write(tmp.path().join("USER.md"), "# User\nName: Test User").unwrap();
|
||||
std::fs::write(tmp.path().join("AGENTS.md"), "# Agents\nFollow instructions.").unwrap();
|
||||
std::fs::write(tmp.path().join("TOOLS.md"), "# Tools\nUse shell carefully.").unwrap();
|
||||
std::fs::write(tmp.path().join("HEARTBEAT.md"), "# Heartbeat\nCheck status.").unwrap();
|
||||
std::fs::write(tmp.path().join("MEMORY.md"), "# Memory\nUser likes Rust.").unwrap();
|
||||
tmp
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_contains_all_sections() {
|
||||
let ws = make_workspace();
|
||||
let tools = vec![("shell", "Run commands"), ("file_read", "Read files")];
|
||||
let prompt = build_system_prompt(ws.path(), "test-model", &tools, &[]);
|
||||
|
||||
// Section headers
|
||||
assert!(prompt.contains("## Tools"), "missing Tools section");
|
||||
assert!(prompt.contains("## Safety"), "missing Safety section");
|
||||
assert!(prompt.contains("## Workspace"), "missing Workspace section");
|
||||
assert!(prompt.contains("## Project Context"), "missing Project Context");
|
||||
assert!(prompt.contains("## Current Date & Time"), "missing Date/Time");
|
||||
assert!(prompt.contains("## Runtime"), "missing Runtime section");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_injects_tools() {
|
||||
let ws = make_workspace();
|
||||
let tools = vec![("shell", "Run commands"), ("memory_recall", "Search memory")];
|
||||
let prompt = build_system_prompt(ws.path(), "gpt-4o", &tools, &[]);
|
||||
|
||||
assert!(prompt.contains("**shell**"));
|
||||
assert!(prompt.contains("Run commands"));
|
||||
assert!(prompt.contains("**memory_recall**"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_injects_safety() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
assert!(prompt.contains("Do not exfiltrate private data"));
|
||||
assert!(prompt.contains("Do not run destructive commands"));
|
||||
assert!(prompt.contains("Prefer `trash` over `rm`"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_injects_workspace_files() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
assert!(prompt.contains("### SOUL.md"), "missing SOUL.md header");
|
||||
assert!(prompt.contains("Be helpful"), "missing SOUL content");
|
||||
assert!(prompt.contains("### IDENTITY.md"), "missing IDENTITY.md");
|
||||
assert!(prompt.contains("Name: ZeroClaw"), "missing IDENTITY content");
|
||||
assert!(prompt.contains("### USER.md"), "missing USER.md");
|
||||
assert!(prompt.contains("### AGENTS.md"), "missing AGENTS.md");
|
||||
assert!(prompt.contains("### TOOLS.md"), "missing TOOLS.md");
|
||||
assert!(prompt.contains("### HEARTBEAT.md"), "missing HEARTBEAT.md");
|
||||
assert!(prompt.contains("### MEMORY.md"), "missing MEMORY.md");
|
||||
assert!(prompt.contains("User likes Rust"), "missing MEMORY content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_missing_file_markers() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
// Empty workspace — no files at all
|
||||
let prompt = build_system_prompt(tmp.path(), "model", &[], &[]);
|
||||
|
||||
assert!(prompt.contains("[File not found: SOUL.md]"));
|
||||
assert!(prompt.contains("[File not found: AGENTS.md]"));
|
||||
assert!(prompt.contains("[File not found: IDENTITY.md]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_bootstrap_only_if_exists() {
|
||||
let ws = make_workspace();
|
||||
// No BOOTSTRAP.md — should not appear
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
assert!(!prompt.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should not appear when missing");
|
||||
|
||||
// Create BOOTSTRAP.md — should appear
|
||||
std::fs::write(ws.path().join("BOOTSTRAP.md"), "# Bootstrap\nFirst run.").unwrap();
|
||||
let prompt2 = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
assert!(prompt2.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should appear when present");
|
||||
assert!(prompt2.contains("First run"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_no_daily_memory_injection() {
|
||||
let ws = make_workspace();
|
||||
let memory_dir = ws.path().join("memory");
|
||||
std::fs::create_dir_all(&memory_dir).unwrap();
|
||||
let today = chrono::Local::now().format("%Y-%m-%d").to_string();
|
||||
std::fs::write(memory_dir.join(format!("{today}.md")), "# Daily\nSome note.").unwrap();
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
// Daily notes should NOT be in the system prompt (on-demand via tools)
|
||||
assert!(!prompt.contains("Daily Notes"), "daily notes should not be auto-injected");
|
||||
assert!(!prompt.contains("Some note"), "daily content should not be in prompt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_runtime_metadata() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "claude-sonnet-4", &[], &[]);
|
||||
|
||||
assert!(prompt.contains("Model: claude-sonnet-4"));
|
||||
assert!(prompt.contains(&format!("OS: {}", std::env::consts::OS)));
|
||||
assert!(prompt.contains("Host:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_skills_compact_list() {
|
||||
let ws = make_workspace();
|
||||
let skills = vec![crate::skills::Skill {
|
||||
name: "code-review".into(),
|
||||
description: "Review code for bugs".into(),
|
||||
version: "1.0.0".into(),
|
||||
author: None,
|
||||
tags: vec![],
|
||||
tools: vec![],
|
||||
prompts: vec!["Long prompt content that should NOT appear in system prompt".into()],
|
||||
}];
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &skills);
|
||||
|
||||
assert!(prompt.contains("<available_skills>"), "missing skills XML");
|
||||
assert!(prompt.contains("<name>code-review</name>"));
|
||||
assert!(prompt.contains("<description>Review code for bugs</description>"));
|
||||
assert!(prompt.contains("SKILL.md</location>"));
|
||||
assert!(prompt.contains("loaded on demand"), "should mention on-demand loading");
|
||||
// Full prompt content should NOT be dumped
|
||||
assert!(!prompt.contains("Long prompt content that should NOT appear"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_truncation() {
|
||||
let ws = make_workspace();
|
||||
// Write a file larger than BOOTSTRAP_MAX_CHARS
|
||||
let big_content = "x".repeat(BOOTSTRAP_MAX_CHARS + 1000);
|
||||
std::fs::write(ws.path().join("AGENTS.md"), &big_content).unwrap();
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
assert!(prompt.contains("truncated at"), "large files should be truncated");
|
||||
assert!(!prompt.contains(&big_content), "full content should not appear");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_empty_files_skipped() {
|
||||
let ws = make_workspace();
|
||||
std::fs::write(ws.path().join("TOOLS.md"), "").unwrap();
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
// Empty file should not produce a header
|
||||
assert!(!prompt.contains("### TOOLS.md"), "empty files should be skipped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_workspace_path() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
|
||||
|
||||
assert!(prompt.contains(&format!("Working directory: `{}`", ws.path().display())));
|
||||
}
|
||||
}
|
||||
174
src/channels/slack.rs
Normal file
174
src/channels/slack.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Slack channel — polls conversations.history via Web API
|
||||
pub struct SlackChannel {
|
||||
bot_token: String,
|
||||
channel_id: Option<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl SlackChannel {
|
||||
pub fn new(bot_token: String, channel_id: Option<String>) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
channel_id,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bot's own user ID so we can ignore our own messages
|
||||
async fn get_bot_user_id(&self) -> Option<String> {
|
||||
let resp: serde_json::Value = self
|
||||
.client
|
||||
.get("https://slack.com/api/auth.test")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.send()
|
||||
.await
|
||||
.ok()?
|
||||
.json()
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
resp.get("user_id")
|
||||
.and_then(|u| u.as_str())
|
||||
.map(String::from)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for SlackChannel {
|
||||
fn name(&self) -> &str {
|
||||
"slack"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, channel: &str) -> anyhow::Result<()> {
|
||||
let body = serde_json::json!({
|
||||
"channel": channel,
|
||||
"text": message
|
||||
});
|
||||
|
||||
self.client
|
||||
.post("https://slack.com/api/chat.postMessage")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let channel_id = self
|
||||
.channel_id
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow::anyhow!("Slack channel_id required for listening"))?;
|
||||
|
||||
let bot_user_id = self.get_bot_user_id().await.unwrap_or_default();
|
||||
let mut last_ts = String::new();
|
||||
|
||||
tracing::info!("Slack channel listening on #{channel_id}...");
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
|
||||
let mut params = vec![
|
||||
("channel", channel_id.clone()),
|
||||
("limit", "10".to_string()),
|
||||
];
|
||||
if !last_ts.is_empty() {
|
||||
params.push(("oldest", last_ts.clone()));
|
||||
}
|
||||
|
||||
let resp = match self
|
||||
.client
|
||||
.get("https://slack.com/api/conversations.history")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.query(¶ms)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack poll error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let data: serde_json::Value = match resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack parse error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(messages) = data.get("messages").and_then(|m| m.as_array()) {
|
||||
// Messages come newest-first, reverse to process oldest first
|
||||
for msg in messages.iter().rev() {
|
||||
let ts = msg.get("ts").and_then(|t| t.as_str()).unwrap_or("");
|
||||
let user = msg
|
||||
.get("user")
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let text = msg.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||
|
||||
// Skip bot's own messages
|
||||
if user == bot_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip empty or already-seen
|
||||
if text.is_empty() || ts <= last_ts.as_str() {
|
||||
continue;
|
||||
}
|
||||
|
||||
last_ts = ts.to_string();
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: channel_id.clone(),
|
||||
content: text.to_string(),
|
||||
channel: "slack".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get("https://slack.com/api/auth.test")
|
||||
.bearer_auth(&self.bot_token)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn slack_channel_name() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None);
|
||||
assert_eq!(ch.name(), "slack");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_channel_with_channel_id() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()));
|
||||
assert_eq!(ch.channel_id, Some("C12345".to_string()));
|
||||
}
|
||||
}
|
||||
182
src/channels/telegram.rs
Normal file
182
src/channels/telegram.rs
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Telegram channel — long-polls the Bot API for updates
|
||||
pub struct TelegramChannel {
|
||||
bot_token: String,
|
||||
allowed_users: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl TelegramChannel {
|
||||
pub fn new(bot_token: String, allowed_users: Vec<String>) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn api_url(&self, method: &str) -> String {
|
||||
format!("https://api.telegram.org/bot{}/{method}", self.bot_token)
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, username: &str) -> bool {
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == username)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for TelegramChannel {
|
||||
fn name(&self) -> &str {
|
||||
"telegram"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||
let body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": message,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
self.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let mut offset: i64 = 0;
|
||||
|
||||
tracing::info!("Telegram channel listening for messages...");
|
||||
|
||||
loop {
|
||||
let url = self.api_url("getUpdates");
|
||||
let body = serde_json::json!({
|
||||
"offset": offset,
|
||||
"timeout": 30,
|
||||
"allowed_updates": ["message"]
|
||||
});
|
||||
|
||||
let resp = match self.client.post(&url).json(&body).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Telegram poll error: {e}");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let data: serde_json::Value = match resp.json().await {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
tracing::warn!("Telegram parse error: {e}");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(results) = data.get("result").and_then(serde_json::Value::as_array) {
|
||||
for update in results {
|
||||
// Advance offset past this update
|
||||
if let Some(uid) = update.get("update_id").and_then(serde_json::Value::as_i64) {
|
||||
offset = uid + 1;
|
||||
}
|
||||
|
||||
let Some(message) = update.get("message") else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let username = message
|
||||
.get("from")
|
||||
.and_then(|f| f.get("username"))
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
if !self.is_user_allowed(username) {
|
||||
tracing::warn!("Telegram: ignoring message from unauthorized user: {username}");
|
||||
continue;
|
||||
}
|
||||
|
||||
let chat_id = message
|
||||
.get("chat")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: chat_id,
|
||||
content: text.to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.client
|
||||
.get(self.api_url("getMe"))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn telegram_channel_name() {
|
||||
let ch = TelegramChannel::new("fake-token".into(), vec!["*".into()]);
|
||||
assert_eq!(ch.name(), "telegram");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_api_url() {
|
||||
let ch = TelegramChannel::new("123:ABC".into(), vec![]);
|
||||
assert_eq!(
|
||||
ch.api_url("getMe"),
|
||||
"https://api.telegram.org/bot123:ABC/getMe"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_user_allowed_wildcard() {
|
||||
let ch = TelegramChannel::new("t".into(), vec!["*".into()]);
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_user_allowed_specific() {
|
||||
let ch = TelegramChannel::new("t".into(), vec!["alice".into(), "bob".into()]);
|
||||
assert!(ch.is_user_allowed("alice"));
|
||||
assert!(!ch.is_user_allowed("eve"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_user_denied_empty() {
|
||||
let ch = TelegramChannel::new("t".into(), vec![]);
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
}
|
||||
29
src/channels/traits.rs
Normal file
29
src/channels/traits.rs
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
use async_trait::async_trait;
|
||||
|
||||
/// A message received from or sent to a channel
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChannelMessage {
|
||||
pub id: String,
|
||||
pub sender: String,
|
||||
pub content: String,
|
||||
pub channel: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
/// Core channel trait — implement for any messaging platform
|
||||
#[async_trait]
|
||||
pub trait Channel: Send + Sync {
|
||||
/// Human-readable channel name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Send a message through this channel
|
||||
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()>;
|
||||
|
||||
/// Start listening for incoming messages (long-running)
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()>;
|
||||
|
||||
/// Check if channel is healthy
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
7
src/config/mod.rs
Normal file
7
src/config/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
pub mod schema;
|
||||
|
||||
pub use schema::{
|
||||
AutonomyConfig, ChannelsConfig, Config, DiscordConfig, HeartbeatConfig, IMessageConfig,
|
||||
MatrixConfig, MemoryConfig, ObservabilityConfig, RuntimeConfig, SlackConfig, TelegramConfig,
|
||||
WebhookConfig,
|
||||
};
|
||||
580
src/config/schema.rs
Normal file
580
src/config/schema.rs
Normal file
|
|
@ -0,0 +1,580 @@
|
|||
use crate::security::AutonomyLevel;
|
||||
use anyhow::{Context, Result};
|
||||
use directories::UserDirs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// ── Top-level config ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub workspace_dir: PathBuf,
|
||||
pub config_path: PathBuf,
|
||||
pub api_key: Option<String>,
|
||||
pub default_provider: Option<String>,
|
||||
pub default_model: Option<String>,
|
||||
pub default_temperature: f64,
|
||||
|
||||
#[serde(default)]
|
||||
pub observability: ObservabilityConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub autonomy: AutonomyConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub runtime: RuntimeConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub heartbeat: HeartbeatConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub channels_config: ChannelsConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub memory: MemoryConfig,
|
||||
}
|
||||
|
||||
// ── Memory ───────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryConfig {
|
||||
/// "sqlite" | "markdown" | "none"
|
||||
pub backend: String,
|
||||
/// Auto-save conversation context to memory
|
||||
pub auto_save: bool,
|
||||
}
|
||||
|
||||
impl Default for MemoryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
backend: "sqlite".into(),
|
||||
auto_save: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Observability ─────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ObservabilityConfig {
|
||||
/// "none" | "log" | "prometheus" | "otel"
|
||||
pub backend: String,
|
||||
}
|
||||
|
||||
impl Default for ObservabilityConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
backend: "none".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Autonomy / Security ──────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AutonomyConfig {
|
||||
pub level: AutonomyLevel,
|
||||
pub workspace_only: bool,
|
||||
pub allowed_commands: Vec<String>,
|
||||
pub forbidden_paths: Vec<String>,
|
||||
pub max_actions_per_hour: u32,
|
||||
pub max_cost_per_day_cents: u32,
|
||||
}
|
||||
|
||||
impl Default for AutonomyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
level: AutonomyLevel::Supervised,
|
||||
workspace_only: true,
|
||||
allowed_commands: vec![
|
||||
"git".into(),
|
||||
"npm".into(),
|
||||
"cargo".into(),
|
||||
"ls".into(),
|
||||
"cat".into(),
|
||||
"grep".into(),
|
||||
"find".into(),
|
||||
"echo".into(),
|
||||
"pwd".into(),
|
||||
"wc".into(),
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
],
|
||||
forbidden_paths: vec![
|
||||
"/etc".into(),
|
||||
"/root".into(),
|
||||
"~/.ssh".into(),
|
||||
"~/.gnupg".into(),
|
||||
],
|
||||
max_actions_per_hour: 20,
|
||||
max_cost_per_day_cents: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Runtime ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RuntimeConfig {
|
||||
/// "native" | "docker" | "cloudflare"
|
||||
pub kind: String,
|
||||
}
|
||||
|
||||
impl Default for RuntimeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
kind: "native".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Heartbeat ────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HeartbeatConfig {
|
||||
pub enabled: bool,
|
||||
pub interval_minutes: u32,
|
||||
}
|
||||
|
||||
impl Default for HeartbeatConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
interval_minutes: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Channels ─────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelsConfig {
|
||||
pub cli: bool,
|
||||
pub telegram: Option<TelegramConfig>,
|
||||
pub discord: Option<DiscordConfig>,
|
||||
pub slack: Option<SlackConfig>,
|
||||
pub webhook: Option<WebhookConfig>,
|
||||
pub imessage: Option<IMessageConfig>,
|
||||
pub matrix: Option<MatrixConfig>,
|
||||
}
|
||||
|
||||
impl Default for ChannelsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
cli: true,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
slack: None,
|
||||
webhook: None,
|
||||
imessage: None,
|
||||
matrix: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TelegramConfig {
|
||||
pub bot_token: String,
|
||||
pub allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DiscordConfig {
|
||||
pub bot_token: String,
|
||||
pub guild_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SlackConfig {
|
||||
pub bot_token: String,
|
||||
pub app_token: Option<String>,
|
||||
pub channel_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebhookConfig {
|
||||
pub port: u16,
|
||||
pub secret: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IMessageConfig {
|
||||
pub allowed_contacts: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MatrixConfig {
|
||||
pub homeserver: String,
|
||||
pub access_token: String,
|
||||
pub room_id: String,
|
||||
pub allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
// ── Config impl ──────────────────────────────────────────────────
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
let home =
|
||||
UserDirs::new().map_or_else(|| PathBuf::from("."), |u| u.home_dir().to_path_buf());
|
||||
let zeroclaw_dir = home.join(".zeroclaw");
|
||||
|
||||
Self {
|
||||
workspace_dir: zeroclaw_dir.join("workspace"),
|
||||
config_path: zeroclaw_dir.join("config.toml"),
|
||||
api_key: None,
|
||||
default_provider: Some("openrouter".to_string()),
|
||||
default_model: Some("anthropic/claude-sonnet-4-20250514".to_string()),
|
||||
default_temperature: 0.7,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
heartbeat: HeartbeatConfig::default(),
|
||||
channels_config: ChannelsConfig::default(),
|
||||
memory: MemoryConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load_or_init() -> Result<Self> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
let zeroclaw_dir = home.join(".zeroclaw");
|
||||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
|
||||
if !zeroclaw_dir.exists() {
|
||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
|
||||
fs::create_dir_all(zeroclaw_dir.join("workspace"))
|
||||
.context("Failed to create workspace directory")?;
|
||||
}
|
||||
|
||||
if config_path.exists() {
|
||||
let contents =
|
||||
fs::read_to_string(&config_path).context("Failed to read config file")?;
|
||||
let config: Config =
|
||||
toml::from_str(&contents).context("Failed to parse config file")?;
|
||||
Ok(config)
|
||||
} else {
|
||||
let config = Config::default();
|
||||
config.save()?;
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save(&self) -> Result<()> {
|
||||
let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?;
|
||||
fs::write(&self.config_path, toml_str).context("Failed to write config file")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// ── Defaults ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn config_default_has_sane_values() {
|
||||
let c = Config::default();
|
||||
assert_eq!(c.default_provider.as_deref(), Some("openrouter"));
|
||||
assert!(c.default_model.as_deref().unwrap().contains("claude"));
|
||||
assert!((c.default_temperature - 0.7).abs() < f64::EPSILON);
|
||||
assert!(c.api_key.is_none());
|
||||
assert!(c.workspace_dir.to_string_lossy().contains("workspace"));
|
||||
assert!(c.config_path.to_string_lossy().contains("config.toml"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn observability_config_default() {
|
||||
let o = ObservabilityConfig::default();
|
||||
assert_eq!(o.backend, "none");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn autonomy_config_default() {
|
||||
let a = AutonomyConfig::default();
|
||||
assert_eq!(a.level, AutonomyLevel::Supervised);
|
||||
assert!(a.workspace_only);
|
||||
assert!(a.allowed_commands.contains(&"git".to_string()));
|
||||
assert!(a.allowed_commands.contains(&"cargo".to_string()));
|
||||
assert!(a.forbidden_paths.contains(&"/etc".to_string()));
|
||||
assert_eq!(a.max_actions_per_hour, 20);
|
||||
assert_eq!(a.max_cost_per_day_cents, 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_config_default() {
|
||||
let r = RuntimeConfig::default();
|
||||
assert_eq!(r.kind, "native");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_config_default() {
|
||||
let h = HeartbeatConfig::default();
|
||||
assert!(!h.enabled);
|
||||
assert_eq!(h.interval_minutes, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channels_config_default() {
|
||||
let c = ChannelsConfig::default();
|
||||
assert!(c.cli);
|
||||
assert!(c.telegram.is_none());
|
||||
assert!(c.discord.is_none());
|
||||
}
|
||||
|
||||
// ── Serde round-trip ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn config_toml_roundtrip() {
|
||||
let config = Config {
|
||||
workspace_dir: PathBuf::from("/tmp/test/workspace"),
|
||||
config_path: PathBuf::from("/tmp/test/config.toml"),
|
||||
api_key: Some("sk-test-key".into()),
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("gpt-4o".into()),
|
||||
default_temperature: 0.5,
|
||||
observability: ObservabilityConfig {
|
||||
backend: "log".into(),
|
||||
},
|
||||
autonomy: AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
workspace_only: false,
|
||||
allowed_commands: vec!["docker".into()],
|
||||
forbidden_paths: vec!["/secret".into()],
|
||||
max_actions_per_hour: 50,
|
||||
max_cost_per_day_cents: 1000,
|
||||
},
|
||||
runtime: RuntimeConfig {
|
||||
kind: "docker".into(),
|
||||
},
|
||||
heartbeat: HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 15,
|
||||
},
|
||||
channels_config: ChannelsConfig {
|
||||
cli: true,
|
||||
telegram: Some(TelegramConfig {
|
||||
bot_token: "123:ABC".into(),
|
||||
allowed_users: vec!["user1".into()],
|
||||
}),
|
||||
discord: None,
|
||||
slack: None,
|
||||
webhook: None,
|
||||
imessage: None,
|
||||
matrix: None,
|
||||
},
|
||||
memory: MemoryConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
let parsed: Config = toml::from_str(&toml_str).unwrap();
|
||||
|
||||
assert_eq!(parsed.api_key, config.api_key);
|
||||
assert_eq!(parsed.default_provider, config.default_provider);
|
||||
assert_eq!(parsed.default_model, config.default_model);
|
||||
assert!((parsed.default_temperature - config.default_temperature).abs() < f64::EPSILON);
|
||||
assert_eq!(parsed.observability.backend, "log");
|
||||
assert_eq!(parsed.autonomy.level, AutonomyLevel::Full);
|
||||
assert!(!parsed.autonomy.workspace_only);
|
||||
assert_eq!(parsed.runtime.kind, "docker");
|
||||
assert!(parsed.heartbeat.enabled);
|
||||
assert_eq!(parsed.heartbeat.interval_minutes, 15);
|
||||
assert!(parsed.channels_config.telegram.is_some());
|
||||
assert_eq!(
|
||||
parsed.channels_config.telegram.unwrap().bot_token,
|
||||
"123:ABC"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_minimal_toml_uses_defaults() {
|
||||
let minimal = r#"
|
||||
workspace_dir = "/tmp/ws"
|
||||
config_path = "/tmp/config.toml"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed: Config = toml::from_str(minimal).unwrap();
|
||||
assert!(parsed.api_key.is_none());
|
||||
assert!(parsed.default_provider.is_none());
|
||||
assert_eq!(parsed.observability.backend, "none");
|
||||
assert_eq!(parsed.autonomy.level, AutonomyLevel::Supervised);
|
||||
assert_eq!(parsed.runtime.kind, "native");
|
||||
assert!(!parsed.heartbeat.enabled);
|
||||
assert!(parsed.channels_config.cli);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_save_and_load_tmpdir() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_config");
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
let config_path = dir.join("config.toml");
|
||||
let config = Config {
|
||||
workspace_dir: dir.join("workspace"),
|
||||
config_path: config_path.clone(),
|
||||
api_key: Some("sk-roundtrip".into()),
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("test-model".into()),
|
||||
default_temperature: 0.9,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
runtime: RuntimeConfig::default(),
|
||||
heartbeat: HeartbeatConfig::default(),
|
||||
channels_config: ChannelsConfig::default(),
|
||||
memory: MemoryConfig::default(),
|
||||
};
|
||||
|
||||
config.save().unwrap();
|
||||
assert!(config_path.exists());
|
||||
|
||||
let contents = fs::read_to_string(&config_path).unwrap();
|
||||
let loaded: Config = toml::from_str(&contents).unwrap();
|
||||
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
|
||||
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
||||
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
||||
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
// ── Telegram / Discord config ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn telegram_config_serde() {
|
||||
let tc = TelegramConfig {
|
||||
bot_token: "123:XYZ".into(),
|
||||
allowed_users: vec!["alice".into(), "bob".into()],
|
||||
};
|
||||
let json = serde_json::to_string(&tc).unwrap();
|
||||
let parsed: TelegramConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.bot_token, "123:XYZ");
|
||||
assert_eq!(parsed.allowed_users.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_config_serde() {
|
||||
let dc = DiscordConfig {
|
||||
bot_token: "discord-token".into(),
|
||||
guild_id: Some("12345".into()),
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.bot_token, "discord-token");
|
||||
assert_eq!(parsed.guild_id.as_deref(), Some("12345"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discord_config_optional_guild() {
|
||||
let dc = DiscordConfig {
|
||||
bot_token: "tok".into(),
|
||||
guild_id: None,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
assert!(parsed.guild_id.is_none());
|
||||
}
|
||||
|
||||
// ── iMessage / Matrix config ────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn imessage_config_serde() {
|
||||
let ic = IMessageConfig {
|
||||
allowed_contacts: vec!["+1234567890".into(), "user@icloud.com".into()],
|
||||
};
|
||||
let json = serde_json::to_string(&ic).unwrap();
|
||||
let parsed: IMessageConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.allowed_contacts.len(), 2);
|
||||
assert_eq!(parsed.allowed_contacts[0], "+1234567890");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn imessage_config_empty_contacts() {
|
||||
let ic = IMessageConfig {
|
||||
allowed_contacts: vec![],
|
||||
};
|
||||
let json = serde_json::to_string(&ic).unwrap();
|
||||
let parsed: IMessageConfig = serde_json::from_str(&json).unwrap();
|
||||
assert!(parsed.allowed_contacts.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn imessage_config_wildcard() {
|
||||
let ic = IMessageConfig {
|
||||
allowed_contacts: vec!["*".into()],
|
||||
};
|
||||
let toml_str = toml::to_string(&ic).unwrap();
|
||||
let parsed: IMessageConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(parsed.allowed_contacts, vec!["*"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matrix_config_serde() {
|
||||
let mc = MatrixConfig {
|
||||
homeserver: "https://matrix.org".into(),
|
||||
access_token: "syt_token_abc".into(),
|
||||
room_id: "!room123:matrix.org".into(),
|
||||
allowed_users: vec!["@user:matrix.org".into()],
|
||||
};
|
||||
let json = serde_json::to_string(&mc).unwrap();
|
||||
let parsed: MatrixConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.homeserver, "https://matrix.org");
|
||||
assert_eq!(parsed.access_token, "syt_token_abc");
|
||||
assert_eq!(parsed.room_id, "!room123:matrix.org");
|
||||
assert_eq!(parsed.allowed_users.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matrix_config_toml_roundtrip() {
|
||||
let mc = MatrixConfig {
|
||||
homeserver: "https://synapse.local:8448".into(),
|
||||
access_token: "tok".into(),
|
||||
room_id: "!abc:synapse.local".into(),
|
||||
allowed_users: vec!["@admin:synapse.local".into(), "*".into()],
|
||||
};
|
||||
let toml_str = toml::to_string(&mc).unwrap();
|
||||
let parsed: MatrixConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(parsed.homeserver, "https://synapse.local:8448");
|
||||
assert_eq!(parsed.allowed_users.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channels_config_with_imessage_and_matrix() {
|
||||
let c = ChannelsConfig {
|
||||
cli: true,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
slack: None,
|
||||
webhook: None,
|
||||
imessage: Some(IMessageConfig {
|
||||
allowed_contacts: vec!["+1".into()],
|
||||
}),
|
||||
matrix: Some(MatrixConfig {
|
||||
homeserver: "https://m.org".into(),
|
||||
access_token: "tok".into(),
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec!["@u:m".into()],
|
||||
}),
|
||||
};
|
||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert!(parsed.imessage.is_some());
|
||||
assert!(parsed.matrix.is_some());
|
||||
assert_eq!(
|
||||
parsed.imessage.unwrap().allowed_contacts,
|
||||
vec!["+1"]
|
||||
);
|
||||
assert_eq!(parsed.matrix.unwrap().homeserver, "https://m.org");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channels_config_default_has_no_imessage_matrix() {
|
||||
let c = ChannelsConfig::default();
|
||||
assert!(c.imessage.is_none());
|
||||
assert!(c.matrix.is_none());
|
||||
}
|
||||
}
|
||||
25
src/cron/mod.rs
Normal file
25
src/cron/mod.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
use crate::config::Config;
|
||||
use anyhow::Result;
|
||||
|
||||
pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> {
|
||||
match command {
|
||||
super::CronCommands::List => {
|
||||
println!("No scheduled tasks yet.");
|
||||
println!("\nUsage:");
|
||||
println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'");
|
||||
Ok(())
|
||||
}
|
||||
super::CronCommands::Add {
|
||||
expression,
|
||||
command,
|
||||
} => {
|
||||
println!("Cron scheduling coming soon!");
|
||||
println!(" Expression: {expression}");
|
||||
println!(" Command: {command}");
|
||||
Ok(())
|
||||
}
|
||||
super::CronCommands::Remove { id } => {
|
||||
anyhow::bail!("Remove task '{id}' not yet implemented");
|
||||
}
|
||||
}
|
||||
}
|
||||
180
src/gateway/mod.rs
Normal file
180
src/gateway/mod.rs
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::providers::{self, Provider};
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Run a minimal HTTP gateway (webhook + health check)
|
||||
/// Zero new dependencies — uses raw TCP + tokio.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
let addr = format!("{host}:{port}");
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
|
||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
)?);
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let temperature = config.default_temperature;
|
||||
let mem: Arc<dyn Memory> =
|
||||
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" GET /health — health check");
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
loop {
|
||||
let (mut stream, peer) = listener.accept().await?;
|
||||
let provider = provider.clone();
|
||||
let model = model.clone();
|
||||
let mem = mem.clone();
|
||||
let auto_save = config.memory.auto_save;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 8192];
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(n) if n > 0 => n,
|
||||
_ => return,
|
||||
};
|
||||
|
||||
let request = String::from_utf8_lossy(&buf[..n]);
|
||||
let first_line = request.lines().next().unwrap_or("");
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
|
||||
if let [method, path, ..] = parts.as_slice() {
|
||||
tracing::info!("{peer} → {method} {path}");
|
||||
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save).await;
|
||||
} else {
|
||||
let _ = send_response(&mut stream, 400, "Bad Request").await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_request(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
method: &str,
|
||||
path: &str,
|
||||
request: &str,
|
||||
provider: &Arc<dyn Provider>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
mem: &Arc<dyn Memory>,
|
||||
auto_save: bool,
|
||||
) {
|
||||
match (method, path) {
|
||||
("GET", "/health") => {
|
||||
let body = serde_json::json!({
|
||||
"status": "ok",
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
"memory": mem.name(),
|
||||
"memory_healthy": mem.health_check().await,
|
||||
});
|
||||
let _ = send_json(stream, 200, &body).await;
|
||||
}
|
||||
|
||||
("POST", "/webhook") => {
|
||||
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
|
||||
}
|
||||
|
||||
_ => {
|
||||
let body = serde_json::json!({
|
||||
"error": "Not found",
|
||||
"routes": ["GET /health", "POST /webhook"]
|
||||
});
|
||||
let _ = send_json(stream, 404, &body).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_webhook(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
request: &str,
|
||||
provider: &Arc<dyn Provider>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
mem: &Arc<dyn Memory>,
|
||||
auto_save: bool,
|
||||
) {
|
||||
let body_str = request
|
||||
.split("\r\n\r\n")
|
||||
.nth(1)
|
||||
.or_else(|| request.split("\n\n").nth(1))
|
||||
.unwrap_or("");
|
||||
|
||||
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body_str) else {
|
||||
let err = serde_json::json!({"error": "Invalid JSON. Expected: {\"message\": \"...\"}"});
|
||||
let _ = send_json(stream, 400, &err).await;
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(message) = parsed.get("message").and_then(|v| v.as_str()) else {
|
||||
let err = serde_json::json!({"error": "Missing 'message' field in JSON"});
|
||||
let _ = send_json(stream, 400, &err).await;
|
||||
return;
|
||||
};
|
||||
|
||||
if auto_save {
|
||||
let _ = mem
|
||||
.store("webhook_msg", message, MemoryCategory::Conversation)
|
||||
.await;
|
||||
}
|
||||
|
||||
match provider.chat(message, model, temperature).await {
|
||||
Ok(response) => {
|
||||
let body = serde_json::json!({"response": response, "model": model});
|
||||
let _ = send_json(stream, 200, &body).await;
|
||||
}
|
||||
Err(e) => {
|
||||
let err = serde_json::json!({"error": format!("LLM error: {e}")});
|
||||
let _ = send_json(stream, 500, &err).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_response(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
status: u16,
|
||||
body: &str,
|
||||
) -> std::io::Result<()> {
|
||||
let reason = match status {
|
||||
200 => "OK",
|
||||
400 => "Bad Request",
|
||||
404 => "Not Found",
|
||||
500 => "Internal Server Error",
|
||||
_ => "Unknown",
|
||||
};
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
|
||||
body.len()
|
||||
);
|
||||
stream.write_all(response.as_bytes()).await
|
||||
}
|
||||
|
||||
async fn send_json(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
status: u16,
|
||||
body: &serde_json::Value,
|
||||
) -> std::io::Result<()> {
|
||||
let reason = match status {
|
||||
200 => "OK",
|
||||
400 => "Bad Request",
|
||||
404 => "Not Found",
|
||||
500 => "Internal Server Error",
|
||||
_ => "Unknown",
|
||||
};
|
||||
let json = serde_json::to_string(body).unwrap_or_default();
|
||||
let response = format!(
|
||||
"HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{json}",
|
||||
json.len()
|
||||
);
|
||||
stream.write_all(response.as_bytes()).await
|
||||
}
|
||||
296
src/heartbeat/engine.rs
Normal file
296
src/heartbeat/engine.rs
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
use crate::config::HeartbeatConfig;
|
||||
use crate::observability::{Observer, ObserverEvent};
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
|
||||
pub struct HeartbeatEngine {
|
||||
config: HeartbeatConfig,
|
||||
workspace_dir: std::path::PathBuf,
|
||||
observer: Arc<dyn Observer>,
|
||||
}
|
||||
|
||||
impl HeartbeatEngine {
|
||||
pub fn new(
|
||||
config: HeartbeatConfig,
|
||||
workspace_dir: std::path::PathBuf,
|
||||
observer: Arc<dyn Observer>,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
workspace_dir,
|
||||
observer,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the heartbeat loop (runs until cancelled)
|
||||
pub async fn run(&self) -> Result<()> {
|
||||
if !self.config.enabled {
|
||||
info!("Heartbeat disabled");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let interval_mins = self.config.interval_minutes.max(5);
|
||||
info!("💓 Heartbeat started: every {} minutes", interval_mins);
|
||||
|
||||
let mut interval = time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
self.observer.record_event(&ObserverEvent::HeartbeatTick);
|
||||
|
||||
match self.tick().await {
|
||||
Ok(tasks) => {
|
||||
if tasks > 0 {
|
||||
info!("💓 Heartbeat: processed {} tasks", tasks);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("💓 Heartbeat error: {}", e);
|
||||
self.observer.record_event(&ObserverEvent::Error {
|
||||
component: "heartbeat".into(),
|
||||
message: e.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Single heartbeat tick — read HEARTBEAT.md and return task count
|
||||
async fn tick(&self) -> Result<usize> {
|
||||
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
|
||||
|
||||
if !heartbeat_path.exists() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let content = tokio::fs::read_to_string(&heartbeat_path).await?;
|
||||
let tasks = Self::parse_tasks(&content);
|
||||
|
||||
Ok(tasks.len())
|
||||
}
|
||||
|
||||
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
|
||||
fn parse_tasks(content: &str) -> Vec<String> {
|
||||
content
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
trimmed.strip_prefix("- ").map(ToString::to_string)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create a default HEARTBEAT.md if it doesn't exist
|
||||
pub async fn ensure_heartbeat_file(workspace_dir: &Path) -> Result<()> {
|
||||
let path = workspace_dir.join("HEARTBEAT.md");
|
||||
if !path.exists() {
|
||||
let default = "# Periodic Tasks\n\n\
|
||||
# Add tasks below (one per line, starting with `- `)\n\
|
||||
# The agent will check this file on each heartbeat tick.\n\
|
||||
#\n\
|
||||
# Examples:\n\
|
||||
# - Check my email for important messages\n\
|
||||
# - Review my calendar for upcoming events\n\
|
||||
# - Check the weather forecast\n";
|
||||
tokio::fs::write(&path, default).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_basic() {
|
||||
let content = "# Tasks\n\n- Check email\n- Review calendar\nNot a task\n- Third task";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert_eq!(tasks[0], "Check email");
|
||||
assert_eq!(tasks[1], "Review calendar");
|
||||
assert_eq!(tasks[2], "Third task");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_empty_content() {
|
||||
assert!(HeartbeatEngine::parse_tasks("").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_only_comments() {
|
||||
let tasks = HeartbeatEngine::parse_tasks("# No tasks here\n\nJust comments\n# Another");
|
||||
assert!(tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_with_leading_whitespace() {
|
||||
let content = " - Indented task\n\t- Tab indented";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Indented task");
|
||||
assert_eq!(tasks[1], "Tab indented");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_dash_without_space_ignored() {
|
||||
let content = "- Real task\n-\n- Another";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
// "-" trimmed = "-", does NOT start with "- " => skipped
|
||||
// "- Real task" => "Real task"
|
||||
// "- Another" => "Another"
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Real task");
|
||||
assert_eq!(tasks[1], "Another");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_trailing_space_bullet_trimmed_to_dash() {
|
||||
// "- " trimmed becomes "-" (trim removes trailing space)
|
||||
// "-" does NOT start with "- " => skipped
|
||||
let content = "- ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_bullet_with_content_after_spaces() {
|
||||
// "- hello " trimmed becomes "- hello" => starts_with "- " => "hello"
|
||||
let content = "- hello ";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_unicode() {
|
||||
let content = "- Check email 📧\n- Review calendar 📅\n- 日本語タスク";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 3);
|
||||
assert!(tasks[0].contains("📧"));
|
||||
assert!(tasks[2].contains("日本語"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_mixed_markdown() {
|
||||
let content = "# Periodic Tasks\n\n## Quick\n- Task A\n\n## Long\n- Task B\n\n* Not a dash bullet\n1. Not numbered";
|
||||
let tasks = HeartbeatEngine::parse_tasks(content);
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0], "Task A");
|
||||
assert_eq!(tasks[1], "Task B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_single_task() {
|
||||
let tasks = HeartbeatEngine::parse_tasks("- Only one");
|
||||
assert_eq!(tasks.len(), 1);
|
||||
assert_eq!(tasks[0], "Only one");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tasks_many_tasks() {
|
||||
let content: String = (0..100).map(|i| format!("- Task {i}\n")).collect();
|
||||
let tasks = HeartbeatEngine::parse_tasks(&content);
|
||||
assert_eq!(tasks.len(), 100);
|
||||
assert_eq!(tasks[99], "Task 99");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_heartbeat_file_creates_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
HeartbeatEngine::ensure_heartbeat_file(&dir).await.unwrap();
|
||||
|
||||
let path = dir.join("HEARTBEAT.md");
|
||||
assert!(path.exists());
|
||||
let content = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert!(content.contains("Periodic Tasks"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn ensure_heartbeat_file_does_not_overwrite() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat_no_overwrite");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let path = dir.join("HEARTBEAT.md");
|
||||
tokio::fs::write(&path, "- My custom task").await.unwrap();
|
||||
|
||||
HeartbeatEngine::ensure_heartbeat_file(&dir).await.unwrap();
|
||||
|
||||
let content = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert_eq!(content, "- My custom task");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tick_returns_zero_when_no_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_tick_no_file");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
|
||||
let engine = HeartbeatEngine::new(
|
||||
HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
},
|
||||
dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let count = engine.tick().await.unwrap();
|
||||
assert_eq!(count, 0);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tick_counts_tasks_from_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_tick_count");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
tokio::fs::write(dir.join("HEARTBEAT.md"), "- A\n- B\n- C")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
|
||||
let engine = HeartbeatEngine::new(
|
||||
HeartbeatConfig {
|
||||
enabled: true,
|
||||
interval_minutes: 30,
|
||||
},
|
||||
dir.clone(),
|
||||
observer,
|
||||
);
|
||||
let count = engine.tick().await.unwrap();
|
||||
assert_eq!(count, 3);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_returns_immediately_when_disabled() {
|
||||
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
|
||||
let engine = HeartbeatEngine::new(
|
||||
HeartbeatConfig {
|
||||
enabled: false,
|
||||
interval_minutes: 30,
|
||||
},
|
||||
std::env::temp_dir(),
|
||||
observer,
|
||||
);
|
||||
// Should return Ok immediately, not loop forever
|
||||
let result = engine.run().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
1
src/heartbeat/mod.rs
Normal file
1
src/heartbeat/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod engine;
|
||||
234
src/integrations/mod.rs
Normal file
234
src/integrations/mod.rs
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
pub mod registry;
|
||||
|
||||
use crate::config::Config;
|
||||
use anyhow::Result;
|
||||
|
||||
/// Integration status
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum IntegrationStatus {
|
||||
/// Fully implemented and ready to use
|
||||
Available,
|
||||
/// Configured and active
|
||||
Active,
|
||||
/// Planned but not yet implemented
|
||||
ComingSoon,
|
||||
}
|
||||
|
||||
/// Integration category
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum IntegrationCategory {
|
||||
Chat,
|
||||
AiModel,
|
||||
Productivity,
|
||||
MusicAudio,
|
||||
SmartHome,
|
||||
ToolsAutomation,
|
||||
MediaCreative,
|
||||
Social,
|
||||
Platform,
|
||||
}
|
||||
|
||||
impl IntegrationCategory {
|
||||
pub fn label(self) -> &'static str {
|
||||
match self {
|
||||
Self::Chat => "Chat Providers",
|
||||
Self::AiModel => "AI Models",
|
||||
Self::Productivity => "Productivity",
|
||||
Self::MusicAudio => "Music & Audio",
|
||||
Self::SmartHome => "Smart Home",
|
||||
Self::ToolsAutomation => "Tools & Automation",
|
||||
Self::MediaCreative => "Media & Creative",
|
||||
Self::Social => "Social",
|
||||
Self::Platform => "Platforms",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn all() -> &'static [Self] {
|
||||
&[
|
||||
Self::Chat,
|
||||
Self::AiModel,
|
||||
Self::Productivity,
|
||||
Self::MusicAudio,
|
||||
Self::SmartHome,
|
||||
Self::ToolsAutomation,
|
||||
Self::MediaCreative,
|
||||
Self::Social,
|
||||
Self::Platform,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/// A registered integration
|
||||
pub struct IntegrationEntry {
|
||||
pub name: &'static str,
|
||||
pub description: &'static str,
|
||||
pub category: IntegrationCategory,
|
||||
pub status_fn: fn(&Config) -> IntegrationStatus,
|
||||
}
|
||||
|
||||
/// Handle the `integrations` CLI command
|
||||
pub fn handle_command(command: super::IntegrationCommands, config: &Config) -> Result<()> {
|
||||
match command {
|
||||
super::IntegrationCommands::List { category } => {
|
||||
list_integrations(config, category.as_deref())
|
||||
}
|
||||
super::IntegrationCommands::Info { name } => show_integration_info(config, &name),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn list_integrations(config: &Config, filter_category: Option<&str>) -> Result<()> {
|
||||
let entries = registry::all_integrations();
|
||||
|
||||
let mut available = 0u32;
|
||||
let mut active = 0u32;
|
||||
let mut coming = 0u32;
|
||||
|
||||
for &cat in IntegrationCategory::all() {
|
||||
// Filter by category if specified
|
||||
if let Some(filter) = filter_category {
|
||||
let filter_lower = filter.to_lowercase();
|
||||
let cat_lower = cat.label().to_lowercase();
|
||||
if !cat_lower.contains(&filter_lower) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let cat_entries: Vec<&IntegrationEntry> =
|
||||
entries.iter().filter(|e| e.category == cat).collect();
|
||||
|
||||
if cat_entries.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
println!("\n ⟩ {}", console::style(cat.label()).white().bold());
|
||||
|
||||
for entry in &cat_entries {
|
||||
let status = (entry.status_fn)(config);
|
||||
let (icon, label) = match status {
|
||||
IntegrationStatus::Active => {
|
||||
active += 1;
|
||||
("✅", console::style("active").green())
|
||||
}
|
||||
IntegrationStatus::Available => {
|
||||
available += 1;
|
||||
("⚪", console::style("available").dim())
|
||||
}
|
||||
IntegrationStatus::ComingSoon => {
|
||||
coming += 1;
|
||||
("🔜", console::style("coming soon").dim())
|
||||
}
|
||||
};
|
||||
println!(
|
||||
" {icon} {:<22} {:<30} {}",
|
||||
console::style(entry.name).white().bold(),
|
||||
entry.description,
|
||||
label
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let total = available + active + coming;
|
||||
println!();
|
||||
println!(" {total} integrations: {active} active, {available} available, {coming} coming soon");
|
||||
println!();
|
||||
println!(" Configure: zeroclaw onboard");
|
||||
println!(" Details: zeroclaw integrations info <name>");
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn show_integration_info(config: &Config, name: &str) -> Result<()> {
|
||||
let entries = registry::all_integrations();
|
||||
let name_lower = name.to_lowercase();
|
||||
|
||||
let Some(entry) = entries.iter().find(|e| e.name.to_lowercase() == name_lower) else {
|
||||
anyhow::bail!(
|
||||
"Unknown integration: {name}. Run `zeroclaw integrations list` to see all."
|
||||
);
|
||||
};
|
||||
|
||||
let status = (entry.status_fn)(config);
|
||||
let (icon, label) = match status {
|
||||
IntegrationStatus::Active => ("✅", "Active"),
|
||||
IntegrationStatus::Available => ("⚪", "Available"),
|
||||
IntegrationStatus::ComingSoon => ("🔜", "Coming Soon"),
|
||||
};
|
||||
|
||||
println!();
|
||||
println!(" {} {} — {}", icon, console::style(entry.name).white().bold(), entry.description);
|
||||
println!(" Category: {}", entry.category.label());
|
||||
println!(" Status: {label}");
|
||||
println!();
|
||||
|
||||
// Show setup hints based on integration
|
||||
match entry.name {
|
||||
"Telegram" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Message @BotFather on Telegram");
|
||||
println!(" 2. Create a bot and copy the token");
|
||||
println!(" 3. Run: zeroclaw onboard");
|
||||
println!(" 4. Start: zeroclaw channel start");
|
||||
}
|
||||
"Discord" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Go to https://discord.com/developers/applications");
|
||||
println!(" 2. Create app → Bot → Copy token");
|
||||
println!(" 3. Enable MESSAGE CONTENT intent");
|
||||
println!(" 4. Run: zeroclaw onboard");
|
||||
}
|
||||
"Slack" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Go to https://api.slack.com/apps");
|
||||
println!(" 2. Create app → Bot Token Scopes → Install");
|
||||
println!(" 3. Run: zeroclaw onboard");
|
||||
}
|
||||
"OpenRouter" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Get API key at https://openrouter.ai/keys");
|
||||
println!(" 2. Run: zeroclaw onboard");
|
||||
println!(" Access 200+ models with one key.");
|
||||
}
|
||||
"Ollama" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Install: brew install ollama");
|
||||
println!(" 2. Pull a model: ollama pull llama3");
|
||||
println!(" 3. Set provider to 'ollama' in config.toml");
|
||||
}
|
||||
"iMessage" => {
|
||||
println!(" Setup (macOS only):");
|
||||
println!(" Uses AppleScript bridge to send/receive iMessages.");
|
||||
println!(" Requires Full Disk Access in System Settings → Privacy.");
|
||||
}
|
||||
"GitHub" => {
|
||||
println!(" Setup:");
|
||||
println!(" 1. Create a personal access token at https://github.com/settings/tokens");
|
||||
println!(" 2. Add to config: [integrations.github] token = \"ghp_...\"");
|
||||
}
|
||||
"Browser" => {
|
||||
println!(" Built-in:");
|
||||
println!(" ZeroClaw can control Chrome/Chromium for web tasks.");
|
||||
println!(" Uses headless browser automation.");
|
||||
}
|
||||
"Cron" => {
|
||||
println!(" Built-in:");
|
||||
println!(" Schedule tasks in ~/.zeroclaw/workspace/cron/");
|
||||
println!(" Run: zeroclaw cron list");
|
||||
}
|
||||
"Webhooks" => {
|
||||
println!(" Built-in:");
|
||||
println!(" HTTP endpoint for external triggers.");
|
||||
println!(" Run: zeroclaw gateway");
|
||||
}
|
||||
_ => {
|
||||
if status == IntegrationStatus::ComingSoon {
|
||||
println!(" This integration is planned. Stay tuned!");
|
||||
println!(" Track progress: https://github.com/theonlyhennygod/zeroclaw");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
821
src/integrations/registry.rs
Normal file
821
src/integrations/registry.rs
Normal file
|
|
@ -0,0 +1,821 @@
|
|||
use super::{IntegrationCategory, IntegrationEntry, IntegrationStatus};
|
||||
|
||||
/// Returns the full catalog of integrations
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
vec![
|
||||
// ── Chat Providers ──────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Telegram",
|
||||
description: "Bot API — long-polling",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.telegram.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Discord",
|
||||
description: "Servers, channels & DMs",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.discord.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Slack",
|
||||
description: "Workspace apps via Web API",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.slack.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Webhooks",
|
||||
description: "HTTP endpoint for triggers",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.webhook.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "WhatsApp",
|
||||
description: "QR pairing via web bridge",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Signal",
|
||||
description: "Privacy-focused via signal-cli",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "iMessage",
|
||||
description: "macOS AppleScript bridge",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.imessage.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Microsoft Teams",
|
||||
description: "Enterprise chat support",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Matrix",
|
||||
description: "Matrix protocol (Element)",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |c| {
|
||||
if c.channels_config.matrix.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Nostr",
|
||||
description: "Decentralized DMs (NIP-04)",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "WebChat",
|
||||
description: "Browser-based chat UI",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Nextcloud Talk",
|
||||
description: "Self-hosted Nextcloud chat",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Zalo",
|
||||
description: "Zalo Bot API",
|
||||
category: IntegrationCategory::Chat,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── AI Models ───────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "OpenRouter",
|
||||
description: "200+ models, 1 API key",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("openrouter") && c.api_key.is_some() {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Anthropic",
|
||||
description: "Claude 3.5/4 Sonnet & Opus",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("anthropic") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "OpenAI",
|
||||
description: "GPT-4o, GPT-5, o1",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("openai") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Google",
|
||||
description: "Gemini 2.5 Pro/Flash",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_model.as_deref().is_some_and(|m| m.starts_with("google/")) {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "DeepSeek",
|
||||
description: "DeepSeek V3 & R1",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_model.as_deref().is_some_and(|m| m.starts_with("deepseek/")) {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "xAI",
|
||||
description: "Grok 3 & 4",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_model.as_deref().is_some_and(|m| m.starts_with("x-ai/")) {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Mistral",
|
||||
description: "Mistral Large & Codestral",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_model.as_deref().is_some_and(|m| m.starts_with("mistral")) {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Ollama",
|
||||
description: "Local models (Llama, etc.)",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("ollama") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Perplexity",
|
||||
description: "Search-augmented AI",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("perplexity") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Hugging Face",
|
||||
description: "Open-source models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "LM Studio",
|
||||
description: "Local model server",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Venice",
|
||||
description: "Privacy-first inference (Llama, Opus)",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("venice") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Vercel AI",
|
||||
description: "Vercel AI Gateway",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("vercel") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Cloudflare AI",
|
||||
description: "Cloudflare AI Gateway",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("cloudflare") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Moonshot",
|
||||
description: "Kimi & Kimi Coding",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("moonshot") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Synthetic",
|
||||
description: "Synthetic AI models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("synthetic") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "OpenCode Zen",
|
||||
description: "Code-focused AI models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("opencode") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Z.AI",
|
||||
description: "Z.AI inference",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("zai") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "GLM",
|
||||
description: "ChatGLM / Zhipu models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("glm") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "MiniMax",
|
||||
description: "MiniMax AI models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("minimax") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Amazon Bedrock",
|
||||
description: "AWS managed model access",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("bedrock") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Qianfan",
|
||||
description: "Baidu AI models",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("qianfan") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Groq",
|
||||
description: "Ultra-fast LPU inference",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("groq") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Together AI",
|
||||
description: "Open-source model hosting",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("together") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Fireworks AI",
|
||||
description: "Fast open-source inference",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("fireworks") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Cohere",
|
||||
description: "Command R+ & embeddings",
|
||||
category: IntegrationCategory::AiModel,
|
||||
status_fn: |c| {
|
||||
if c.default_provider.as_deref() == Some("cohere") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
// ── Productivity ────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "GitHub",
|
||||
description: "Code, issues, PRs",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Notion",
|
||||
description: "Workspace & databases",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Apple Notes",
|
||||
description: "Native macOS/iOS notes",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Apple Reminders",
|
||||
description: "Task management",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Obsidian",
|
||||
description: "Knowledge graph notes",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Things 3",
|
||||
description: "GTD task manager",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Bear Notes",
|
||||
description: "Markdown notes",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Trello",
|
||||
description: "Kanban boards",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Linear",
|
||||
description: "Issue tracking",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Music & Audio ───────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Spotify",
|
||||
description: "Music playback control",
|
||||
category: IntegrationCategory::MusicAudio,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Sonos",
|
||||
description: "Multi-room audio",
|
||||
category: IntegrationCategory::MusicAudio,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Shazam",
|
||||
description: "Song recognition",
|
||||
category: IntegrationCategory::MusicAudio,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Smart Home ──────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Home Assistant",
|
||||
description: "Home automation hub",
|
||||
category: IntegrationCategory::SmartHome,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Philips Hue",
|
||||
description: "Smart lighting",
|
||||
category: IntegrationCategory::SmartHome,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "8Sleep",
|
||||
description: "Smart mattress",
|
||||
category: IntegrationCategory::SmartHome,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Tools & Automation ──────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Browser",
|
||||
description: "Chrome/Chromium control",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Shell",
|
||||
description: "Terminal command execution",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Active,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "File System",
|
||||
description: "Read/write files",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Active,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Cron",
|
||||
description: "Scheduled tasks",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Voice",
|
||||
description: "Voice wake + talk mode",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Gmail",
|
||||
description: "Email triggers & send",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "1Password",
|
||||
description: "Secure credentials",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Weather",
|
||||
description: "Forecasts & conditions",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Canvas",
|
||||
description: "Visual workspace + A2UI",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Media & Creative ────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Image Gen",
|
||||
description: "AI image generation",
|
||||
category: IntegrationCategory::MediaCreative,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "GIF Search",
|
||||
description: "Find the perfect GIF",
|
||||
category: IntegrationCategory::MediaCreative,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Screen Capture",
|
||||
description: "Screenshot & screen control",
|
||||
category: IntegrationCategory::MediaCreative,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Camera",
|
||||
description: "Photo/video capture",
|
||||
category: IntegrationCategory::MediaCreative,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Social ──────────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Twitter/X",
|
||||
description: "Tweet, reply, search",
|
||||
category: IntegrationCategory::Social,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Email",
|
||||
description: "Send & read emails",
|
||||
category: IntegrationCategory::Social,
|
||||
status_fn: |_| IntegrationStatus::ComingSoon,
|
||||
},
|
||||
// ── Platforms ───────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "macOS",
|
||||
description: "Native support + AppleScript",
|
||||
category: IntegrationCategory::Platform,
|
||||
status_fn: |_| {
|
||||
if cfg!(target_os = "macos") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Linux",
|
||||
description: "Native support",
|
||||
category: IntegrationCategory::Platform,
|
||||
status_fn: |_| {
|
||||
if cfg!(target_os = "linux") {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Windows",
|
||||
description: "WSL2 recommended",
|
||||
category: IntegrationCategory::Platform,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "iOS",
|
||||
description: "Chat via Telegram/Discord",
|
||||
category: IntegrationCategory::Platform,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Android",
|
||||
description: "Chat via Telegram/Discord",
|
||||
category: IntegrationCategory::Platform,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
use crate::config::schema::{
|
||||
ChannelsConfig, IMessageConfig, MatrixConfig, TelegramConfig,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn registry_has_entries() {
|
||||
let entries = all_integrations();
|
||||
assert!(entries.len() >= 50, "Expected 50+ integrations, got {}", entries.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_categories_represented() {
|
||||
let entries = all_integrations();
|
||||
for cat in IntegrationCategory::all() {
|
||||
let count = entries.iter().filter(|e| e.category == *cat).count();
|
||||
assert!(count > 0, "Category {:?} has no entries", cat);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_functions_dont_panic() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
for entry in &entries {
|
||||
let _ = (entry.status_fn)(&config);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_duplicate_names() {
|
||||
let entries = all_integrations();
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for entry in &entries {
|
||||
assert!(
|
||||
seen.insert(entry.name),
|
||||
"Duplicate integration name: {}",
|
||||
entry.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_empty_names_or_descriptions() {
|
||||
let entries = all_integrations();
|
||||
for entry in &entries {
|
||||
assert!(!entry.name.is_empty(), "Found integration with empty name");
|
||||
assert!(
|
||||
!entry.description.is_empty(),
|
||||
"Integration '{}' has empty description",
|
||||
entry.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_active_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.telegram = Some(TelegramConfig {
|
||||
bot_token: "123:ABC".into(),
|
||||
allowed_users: vec!["user".into()],
|
||||
});
|
||||
let entries = all_integrations();
|
||||
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
|
||||
assert!(matches!((tg.status_fn)(&config), IntegrationStatus::Active));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn telegram_available_when_not_configured() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
|
||||
assert!(matches!((tg.status_fn)(&config), IntegrationStatus::Available));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn imessage_active_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.imessage = Some(IMessageConfig {
|
||||
allowed_contacts: vec!["*".into()],
|
||||
});
|
||||
let entries = all_integrations();
|
||||
let im = entries.iter().find(|e| e.name == "iMessage").unwrap();
|
||||
assert!(matches!((im.status_fn)(&config), IntegrationStatus::Active));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn imessage_available_when_not_configured() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
let im = entries.iter().find(|e| e.name == "iMessage").unwrap();
|
||||
assert!(matches!((im.status_fn)(&config), IntegrationStatus::Available));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matrix_active_when_configured() {
|
||||
let mut config = Config::default();
|
||||
config.channels_config.matrix = Some(MatrixConfig {
|
||||
homeserver: "https://m.org".into(),
|
||||
access_token: "tok".into(),
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec![],
|
||||
});
|
||||
let entries = all_integrations();
|
||||
let mx = entries.iter().find(|e| e.name == "Matrix").unwrap();
|
||||
assert!(matches!((mx.status_fn)(&config), IntegrationStatus::Active));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matrix_available_when_not_configured() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
let mx = entries.iter().find(|e| e.name == "Matrix").unwrap();
|
||||
assert!(matches!((mx.status_fn)(&config), IntegrationStatus::Available));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coming_soon_integrations_stay_coming_soon() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
for name in ["WhatsApp", "Signal", "Nostr", "Spotify", "Home Assistant"] {
|
||||
let entry = entries.iter().find(|e| e.name == name).unwrap();
|
||||
assert!(
|
||||
matches!((entry.status_fn)(&config), IntegrationStatus::ComingSoon),
|
||||
"{name} should be ComingSoon"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_and_filesystem_always_active() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
for name in ["Shell", "File System"] {
|
||||
let entry = entries.iter().find(|e| e.name == name).unwrap();
|
||||
assert!(
|
||||
matches!((entry.status_fn)(&config), IntegrationStatus::Active),
|
||||
"{name} should always be Active"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn macos_active_on_macos() {
|
||||
let config = Config::default();
|
||||
let entries = all_integrations();
|
||||
let macos = entries.iter().find(|e| e.name == "macOS").unwrap();
|
||||
let status = (macos.status_fn)(&config);
|
||||
if cfg!(target_os = "macos") {
|
||||
assert!(matches!(status, IntegrationStatus::Active));
|
||||
} else {
|
||||
assert!(matches!(status, IntegrationStatus::Available));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn category_counts_reasonable() {
|
||||
let entries = all_integrations();
|
||||
let chat_count = entries.iter().filter(|e| e.category == IntegrationCategory::Chat).count();
|
||||
let ai_count = entries.iter().filter(|e| e.category == IntegrationCategory::AiModel).count();
|
||||
assert!(chat_count >= 5, "Expected 5+ chat integrations, got {chat_count}");
|
||||
assert!(ai_count >= 5, "Expected 5+ AI model integrations, got {ai_count}");
|
||||
}
|
||||
}
|
||||
20
src/lib.rs
Normal file
20
src/lib.rs
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#![warn(clippy::all, clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::missing_errors_doc,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::unnecessary_literal_bound,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::struct_field_names,
|
||||
clippy::must_use_candidate,
|
||||
clippy::new_without_default,
|
||||
clippy::return_self_not_must_use,
|
||||
dead_code
|
||||
)]
|
||||
|
||||
pub mod config;
|
||||
pub mod heartbeat;
|
||||
pub mod memory;
|
||||
pub mod observability;
|
||||
pub mod providers;
|
||||
pub mod runtime;
|
||||
pub mod security;
|
||||
326
src/main.rs
Normal file
326
src/main.rs
Normal file
|
|
@ -0,0 +1,326 @@
|
|||
#![warn(clippy::all, clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::missing_errors_doc,
|
||||
clippy::missing_panics_doc,
|
||||
clippy::unnecessary_literal_bound,
|
||||
clippy::module_name_repetitions,
|
||||
clippy::struct_field_names,
|
||||
dead_code
|
||||
)]
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use tracing::{info, Level};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
|
||||
mod agent;
|
||||
mod channels;
|
||||
mod config;
|
||||
mod cron;
|
||||
mod gateway;
|
||||
mod heartbeat;
|
||||
mod memory;
|
||||
mod observability;
|
||||
mod onboard;
|
||||
mod providers;
|
||||
mod runtime;
|
||||
mod security;
|
||||
mod integrations;
|
||||
mod skills;
|
||||
mod tools;
|
||||
|
||||
use config::Config;
|
||||
|
||||
/// `ZeroClaw` - Zero overhead. Zero compromise. 100% Rust.
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeroclaw")]
|
||||
#[command(author = "theonlyhennygod")]
|
||||
#[command(version = "0.1.0")]
|
||||
#[command(about = "The fastest, smallest AI assistant.", long_about = None)]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
/// Initialize your workspace and configuration
|
||||
Onboard,
|
||||
|
||||
/// Start the AI agent loop
|
||||
Agent {
|
||||
/// Single message mode (don't enter interactive mode)
|
||||
#[arg(short, long)]
|
||||
message: Option<String>,
|
||||
|
||||
/// Provider to use (openrouter, anthropic, openai)
|
||||
#[arg(short, long)]
|
||||
provider: Option<String>,
|
||||
|
||||
/// Model to use
|
||||
#[arg(short, long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// Temperature (0.0 - 2.0)
|
||||
#[arg(short, long, default_value = "0.7")]
|
||||
temperature: f64,
|
||||
},
|
||||
|
||||
/// Start the gateway server (webhooks, websockets)
|
||||
Gateway {
|
||||
/// Port to listen on
|
||||
#[arg(short, long, default_value = "8080")]
|
||||
port: u16,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(short, long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
},
|
||||
|
||||
/// Show system status
|
||||
Status {
|
||||
/// Show detailed status
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
},
|
||||
|
||||
/// Configure and manage scheduled tasks
|
||||
Cron {
|
||||
#[command(subcommand)]
|
||||
cron_command: CronCommands,
|
||||
},
|
||||
|
||||
/// Manage channels (telegram, discord, slack)
|
||||
Channel {
|
||||
#[command(subcommand)]
|
||||
channel_command: ChannelCommands,
|
||||
},
|
||||
|
||||
/// Tool utilities
|
||||
Tools {
|
||||
#[command(subcommand)]
|
||||
tool_command: ToolCommands,
|
||||
},
|
||||
|
||||
/// Browse 50+ integrations
|
||||
Integrations {
|
||||
#[command(subcommand)]
|
||||
integration_command: IntegrationCommands,
|
||||
},
|
||||
|
||||
/// Manage skills (user-defined capabilities)
|
||||
Skills {
|
||||
#[command(subcommand)]
|
||||
skill_command: SkillCommands,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum CronCommands {
|
||||
/// List all scheduled tasks
|
||||
List,
|
||||
/// Add a new scheduled task
|
||||
Add {
|
||||
/// Cron expression
|
||||
expression: String,
|
||||
/// Command to run
|
||||
command: String,
|
||||
},
|
||||
/// Remove a scheduled task
|
||||
Remove {
|
||||
/// Task ID
|
||||
id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum ChannelCommands {
|
||||
/// List configured channels
|
||||
List,
|
||||
/// Start all configured channels (Telegram, Discord, Slack)
|
||||
Start,
|
||||
/// Add a new channel
|
||||
Add {
|
||||
/// Channel type
|
||||
channel_type: String,
|
||||
/// Configuration JSON
|
||||
config: String,
|
||||
},
|
||||
/// Remove a channel
|
||||
Remove {
|
||||
/// Channel name
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum SkillCommands {
|
||||
/// List installed skills
|
||||
List,
|
||||
/// Install a skill from a GitHub URL or local path
|
||||
Install {
|
||||
/// GitHub URL or local path
|
||||
source: String,
|
||||
},
|
||||
/// Remove an installed skill
|
||||
Remove {
|
||||
/// Skill name
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum IntegrationCommands {
|
||||
/// List all integrations and their status
|
||||
List {
|
||||
/// Filter by category (e.g. "chat", "ai", "productivity")
|
||||
#[arg(short, long)]
|
||||
category: Option<String>,
|
||||
},
|
||||
/// Show details about a specific integration
|
||||
Info {
|
||||
/// Integration name
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum ToolCommands {
|
||||
/// List available tools
|
||||
List,
|
||||
/// Test a tool
|
||||
Test {
|
||||
/// Tool name
|
||||
tool: String,
|
||||
/// Tool arguments (JSON)
|
||||
args: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_max_level(Level::INFO)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
||||
// Onboard runs the interactive wizard — no existing config needed
|
||||
if matches!(cli.command, Commands::Onboard) {
|
||||
let config = onboard::run_wizard()?;
|
||||
// Auto-start channels if user said yes during wizard
|
||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||
channels::start_channels(config).await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// All other commands need config loaded first
|
||||
let config = Config::load_or_init()?;
|
||||
|
||||
match cli.command {
|
||||
Commands::Onboard => unreachable!(),
|
||||
|
||||
Commands::Agent {
|
||||
message,
|
||||
provider,
|
||||
model,
|
||||
temperature,
|
||||
} => agent::run(config, message, provider, model, temperature).await,
|
||||
|
||||
Commands::Gateway { port, host } => {
|
||||
info!("🚀 Starting ZeroClaw Gateway on {host}:{port}");
|
||||
info!("POST http://{host}:{port}/webhook — send JSON messages");
|
||||
info!("GET http://{host}:{port}/health — health check");
|
||||
gateway::run_gateway(&host, port, config).await
|
||||
}
|
||||
|
||||
Commands::Status { verbose } => {
|
||||
println!("🦀 ZeroClaw Status");
|
||||
println!();
|
||||
println!("Version: {}", env!("CARGO_PKG_VERSION"));
|
||||
println!("Workspace: {}", config.workspace_dir.display());
|
||||
println!("Config: {}", config.config_path.display());
|
||||
println!();
|
||||
println!(
|
||||
"🤖 Provider: {}",
|
||||
config.default_provider.as_deref().unwrap_or("openrouter")
|
||||
);
|
||||
println!(
|
||||
" Model: {}",
|
||||
config.default_model.as_deref().unwrap_or("(default)")
|
||||
);
|
||||
println!("📊 Observability: {}", config.observability.backend);
|
||||
println!("🛡️ Autonomy: {:?}", config.autonomy.level);
|
||||
println!("⚙️ Runtime: {}", config.runtime.kind);
|
||||
println!(
|
||||
"💓 Heartbeat: {}",
|
||||
if config.heartbeat.enabled {
|
||||
format!("every {}min", config.heartbeat.interval_minutes)
|
||||
} else {
|
||||
"disabled".into()
|
||||
}
|
||||
);
|
||||
println!(
|
||||
"🧠 Memory: {} (auto-save: {})",
|
||||
config.memory.backend,
|
||||
if config.memory.auto_save { "on" } else { "off" }
|
||||
);
|
||||
|
||||
if verbose {
|
||||
println!();
|
||||
println!("Security:");
|
||||
println!(" Workspace only: {}", config.autonomy.workspace_only);
|
||||
println!(
|
||||
" Allowed commands: {}",
|
||||
config.autonomy.allowed_commands.join(", ")
|
||||
);
|
||||
println!(
|
||||
" Max actions/hour: {}",
|
||||
config.autonomy.max_actions_per_hour
|
||||
);
|
||||
println!(
|
||||
" Max cost/day: ${:.2}",
|
||||
f64::from(config.autonomy.max_cost_per_day_cents) / 100.0
|
||||
);
|
||||
println!();
|
||||
println!("Channels:");
|
||||
println!(" CLI: ✅ always");
|
||||
for (name, configured) in [
|
||||
("Telegram", config.channels_config.telegram.is_some()),
|
||||
("Discord", config.channels_config.discord.is_some()),
|
||||
("Slack", config.channels_config.slack.is_some()),
|
||||
("Webhook", config.channels_config.webhook.is_some()),
|
||||
] {
|
||||
println!(
|
||||
" {name:9} {}",
|
||||
if configured { "✅ configured" } else { "❌ not configured" }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Commands::Cron { cron_command } => cron::handle_command(cron_command, config),
|
||||
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
other => channels::handle_command(other, &config),
|
||||
},
|
||||
|
||||
Commands::Tools { tool_command } => tools::handle_command(tool_command, config).await,
|
||||
|
||||
Commands::Integrations {
|
||||
integration_command,
|
||||
} => integrations::handle_command(integration_command, &config),
|
||||
|
||||
Commands::Skills { skill_command } => {
|
||||
skills::handle_command(skill_command, &config.workspace_dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
344
src/memory/markdown.rs
Normal file
344
src/memory/markdown.rs
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Markdown-based memory — plain files as source of truth
|
||||
///
|
||||
/// Layout:
|
||||
/// workspace/MEMORY.md — curated long-term memory (core)
|
||||
/// workspace/memory/YYYY-MM-DD.md — daily logs (append-only)
|
||||
pub struct MarkdownMemory {
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl MarkdownMemory {
|
||||
pub fn new(workspace_dir: &Path) -> Self {
|
||||
Self {
|
||||
workspace_dir: workspace_dir.to_path_buf(),
|
||||
}
|
||||
}
|
||||
|
||||
fn memory_dir(&self) -> PathBuf {
|
||||
self.workspace_dir.join("memory")
|
||||
}
|
||||
|
||||
fn core_path(&self) -> PathBuf {
|
||||
self.workspace_dir.join("MEMORY.md")
|
||||
}
|
||||
|
||||
fn daily_path(&self) -> PathBuf {
|
||||
let date = Local::now().format("%Y-%m-%d").to_string();
|
||||
self.memory_dir().join(format!("{date}.md"))
|
||||
}
|
||||
|
||||
async fn ensure_dirs(&self) -> anyhow::Result<()> {
|
||||
fs::create_dir_all(self.memory_dir()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn append_to_file(&self, path: &Path, content: &str) -> anyhow::Result<()> {
|
||||
self.ensure_dirs().await?;
|
||||
|
||||
let existing = if path.exists() {
|
||||
fs::read_to_string(path).await.unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let updated = if existing.is_empty() {
|
||||
let header = if path == self.core_path() {
|
||||
"# Long-Term Memory\n\n"
|
||||
} else {
|
||||
let date = Local::now().format("%Y-%m-%d").to_string();
|
||||
&format!("# Daily Log — {date}\n\n")
|
||||
};
|
||||
format!("{header}{content}\n")
|
||||
} else {
|
||||
format!("{existing}\n{content}\n")
|
||||
};
|
||||
|
||||
fs::write(path, updated).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_entries_from_file(
|
||||
path: &Path,
|
||||
content: &str,
|
||||
category: &MemoryCategory,
|
||||
) -> Vec<MemoryEntry> {
|
||||
let filename = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
content
|
||||
.lines()
|
||||
.filter(|line| {
|
||||
let trimmed = line.trim();
|
||||
!trimmed.is_empty() && !trimmed.starts_with('#')
|
||||
})
|
||||
.enumerate()
|
||||
.map(|(i, line)| {
|
||||
let trimmed = line.trim();
|
||||
let clean = trimmed.strip_prefix("- ").unwrap_or(trimmed);
|
||||
MemoryEntry {
|
||||
id: format!("{filename}:{i}"),
|
||||
key: format!("{filename}:{i}"),
|
||||
content: clean.to_string(),
|
||||
category: category.clone(),
|
||||
timestamp: filename.to_string(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn read_all_entries(&self) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
// Read MEMORY.md (core)
|
||||
let core_path = self.core_path();
|
||||
if core_path.exists() {
|
||||
let content = fs::read_to_string(&core_path).await?;
|
||||
entries.extend(Self::parse_entries_from_file(
|
||||
&core_path,
|
||||
&content,
|
||||
&MemoryCategory::Core,
|
||||
));
|
||||
}
|
||||
|
||||
// Read daily logs
|
||||
let mem_dir = self.memory_dir();
|
||||
if mem_dir.exists() {
|
||||
let mut dir = fs::read_dir(&mem_dir).await?;
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) == Some("md") {
|
||||
let content = fs::read_to_string(&path).await?;
|
||||
entries.extend(Self::parse_entries_from_file(
|
||||
&path,
|
||||
&content,
|
||||
&MemoryCategory::Daily,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entries.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
|
||||
Ok(entries)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for MarkdownMemory {
|
||||
fn name(&self) -> &str {
|
||||
"markdown"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = format!("- **{key}**: {content}");
|
||||
let path = match category {
|
||||
MemoryCategory::Core => self.core_path(),
|
||||
_ => self.daily_path(),
|
||||
};
|
||||
self.append_to_file(&path, &entry).await
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
||||
let mut scored: Vec<MemoryEntry> = all
|
||||
.into_iter()
|
||||
.filter_map(|mut entry| {
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = keywords
|
||||
.iter()
|
||||
.filter(|kw| content_lower.contains(**kw))
|
||||
.count();
|
||||
if matched > 0 {
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let score = matched as f64 / keywords.len() as f64;
|
||||
entry.score = Some(score);
|
||||
Some(entry)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
scored.truncate(limit);
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
Ok(all
|
||||
.into_iter()
|
||||
.find(|e| e.key == key || e.content.contains(key)))
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
match category {
|
||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
None => Ok(all),
|
||||
}
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
// Markdown memory is append-only by design (audit trail)
|
||||
// Return false to indicate the entry wasn't removed
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let all = self.read_all_entries().await?;
|
||||
Ok(all.len())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.workspace_dir.exists()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs as sync_fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_workspace() -> (TempDir, MarkdownMemory) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = MarkdownMemory::new(tmp.path());
|
||||
(tmp, mem)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_name() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_health_check() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert!(mem.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_store_core() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||
assert!(content.contains("User likes Rust"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_store_daily() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
let path = mem.daily_path();
|
||||
let content = sync_fs::read_to_string(path).unwrap();
|
||||
assert!(content.contains("Finished tests"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_recall_keyword() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|r| r.content.to_lowercase().contains("rust")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_recall_no_match() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = mem.count().await.unwrap();
|
||||
assert!(count >= 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_list_by_category() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "core fact", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_forget_is_noop() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "permanent", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let removed = mem.forget("a").await.unwrap();
|
||||
assert!(!removed, "Markdown memory is append-only");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
}
|
||||
77
src/memory/mod.rs
Normal file
77
src/memory/mod.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
pub mod markdown;
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
|
||||
pub use markdown::MarkdownMemory;
|
||||
pub use sqlite::SqliteMemory;
|
||||
pub use traits::Memory;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{MemoryCategory, MemoryEntry};
|
||||
|
||||
use crate::config::MemoryConfig;
|
||||
use std::path::Path;
|
||||
|
||||
/// Factory: create the right memory backend from config
|
||||
pub fn create_memory(
|
||||
config: &MemoryConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
match config.backend.as_str() {
|
||||
"sqlite" => Ok(Box::new(SqliteMemory::new(workspace_dir)?)),
|
||||
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
|
||||
other => {
|
||||
tracing::warn!("Unknown memory backend '{other}', falling back to markdown");
|
||||
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn factory_sqlite() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "sqlite".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "sqlite");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "markdown".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_none_falls_back_to_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "none".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_unknown_falls_back_to_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "redis".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
}
|
||||
481
src/memory/sqlite.rs
Normal file
481
src/memory/sqlite.rs
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// SQLite-backed persistent memory — the brain
|
||||
///
|
||||
/// Stores memories in a local `SQLite` database with keyword search.
|
||||
/// Zero external dependencies, works offline, survives restarts.
|
||||
pub struct SqliteMemory {
|
||||
conn: Mutex<Connection>,
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl SqliteMemory {
|
||||
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join("brain.db");
|
||||
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(&db_path)?;
|
||||
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
key TEXT NOT NULL UNIQUE,
|
||||
content TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'core',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);",
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
db_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn category_to_str(cat: &MemoryCategory) -> String {
|
||||
match cat {
|
||||
MemoryCategory::Core => "core".into(),
|
||||
MemoryCategory::Daily => "daily".into(),
|
||||
MemoryCategory::Conversation => "conversation".into(),
|
||||
MemoryCategory::Custom(name) => name.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_category(s: &str) -> MemoryCategory {
|
||||
match s {
|
||||
"core" => MemoryCategory::Core,
|
||||
"daily" => MemoryCategory::Daily,
|
||||
"conversation" => MemoryCategory::Conversation,
|
||||
other => MemoryCategory::Custom(other.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for SqliteMemory {
|
||||
fn name(&self) -> &str {
|
||||
"sqlite"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let now = Local::now().to_rfc3339();
|
||||
let cat = Self::category_to_str(&category);
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
updated_at = excluded.updated_at",
|
||||
params![id, key, content, cat, now, now],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
// Keyword search: split query into words, match any
|
||||
let keywords: Vec<String> = query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
||||
|
||||
if keywords.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Build dynamic WHERE clause for keyword matching
|
||||
let conditions: Vec<String> = keywords
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2))
|
||||
.collect();
|
||||
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
WHERE {where_clause}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
keywords.len() * 2 + 1
|
||||
);
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
|
||||
// Build params: each keyword appears twice (for content and key)
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
for kw in &keywords {
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
}
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
param_values.iter().map(AsRef::as_ref).collect();
|
||||
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: Some(1.0),
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
|
||||
// Score by keyword match count
|
||||
let query_lower = query.to_lowercase();
|
||||
let kw_list: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
for entry in &mut results {
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = kw_list
|
||||
.iter()
|
||||
.filter(|kw| content_lower.contains(**kw))
|
||||
.count();
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
{
|
||||
entry.score = Some(matched as f64 / kw_list.len().max(1) as f64);
|
||||
}
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
||||
match rows.next() {
|
||||
Some(Ok(entry)) => Ok(Some(entry)),
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: None,
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(count as usize)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.conn
|
||||
.lock()
|
||||
.map(|c| c.execute_batch("SELECT 1").is_ok())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_sqlite() -> (TempDir, SqliteMemory) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, mem)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_name() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert_eq!(mem.name(), "sqlite");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_health() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert!(mem.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_store_and_get() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entry = mem.get("user_lang").await.unwrap();
|
||||
assert!(entry.is_some());
|
||||
let entry = entry.unwrap();
|
||||
assert_eq!(entry.key, "user_lang");
|
||||
assert_eq!(entry.content, "Prefers Rust");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_store_upsert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entry = mem.get("pref").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "loves Rust");
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|r| r.content.to_lowercase().contains("rust")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_multi_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_no_match() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
||||
let removed = mem.forget("temp").await.unwrap();
|
||||
assert!(removed);
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget_nonexistent() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let removed = mem.forget("nope").await.unwrap();
|
||||
assert!(!removed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = mem.list(None).await.unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_by_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
assert_eq!(core.len(), 2);
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
assert_eq!(daily.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_count_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_get_nonexistent() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert!(mem.get("nope").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_db_persists() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Reopen
|
||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let entry = mem2.get("persist").await.unwrap();
|
||||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "I survive restarts");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_category_roundtrip() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let categories = vec![
|
||||
MemoryCategory::Core,
|
||||
MemoryCategory::Daily,
|
||||
MemoryCategory::Conversation,
|
||||
MemoryCategory::Custom("project".into()),
|
||||
];
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
|
||||
assert_eq!(&entry.category, cat);
|
||||
}
|
||||
}
|
||||
}
|
||||
68
src/memory/traits.rs
Normal file
68
src/memory/traits.rs
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A single memory entry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryEntry {
|
||||
pub id: String,
|
||||
pub key: String,
|
||||
pub content: String,
|
||||
pub category: MemoryCategory,
|
||||
pub timestamp: String,
|
||||
pub session_id: Option<String>,
|
||||
pub score: Option<f64>,
|
||||
}
|
||||
|
||||
/// Memory categories for organization
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MemoryCategory {
|
||||
/// Long-term facts, preferences, decisions
|
||||
Core,
|
||||
/// Daily session logs
|
||||
Daily,
|
||||
/// Conversation context
|
||||
Conversation,
|
||||
/// User-defined custom category
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MemoryCategory {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Core => write!(f, "core"),
|
||||
Self::Daily => write!(f, "daily"),
|
||||
Self::Conversation => write!(f, "conversation"),
|
||||
Self::Custom(name) => write!(f, "{name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Core memory trait — implement for any persistence backend
|
||||
#[async_trait]
|
||||
pub trait Memory: Send + Sync {
|
||||
/// Backend name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Store a memory entry
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search)
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
|
||||
/// List all memory keys, optionally filtered by category
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
||||
/// Count total memories
|
||||
async fn count(&self) -> anyhow::Result<usize>;
|
||||
|
||||
/// Health check
|
||||
async fn health_check(&self) -> bool;
|
||||
}
|
||||
119
src/observability/log.rs
Normal file
119
src/observability/log.rs
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
use super::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
use tracing::info;
|
||||
|
||||
/// Log-based observer — uses tracing, zero external deps
|
||||
pub struct LogObserver;
|
||||
|
||||
impl LogObserver {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer for LogObserver {
|
||||
fn record_event(&self, event: &ObserverEvent) {
|
||||
match event {
|
||||
ObserverEvent::AgentStart { provider, model } => {
|
||||
info!(provider = %provider, model = %model, "agent.start");
|
||||
}
|
||||
ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used,
|
||||
} => {
|
||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
|
||||
}
|
||||
ObserverEvent::ToolCall {
|
||||
tool,
|
||||
duration,
|
||||
success,
|
||||
} => {
|
||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(tool = %tool, duration_ms = ms, success = success, "tool.call");
|
||||
}
|
||||
ObserverEvent::ChannelMessage { channel, direction } => {
|
||||
info!(channel = %channel, direction = %direction, "channel.message");
|
||||
}
|
||||
ObserverEvent::HeartbeatTick => {
|
||||
info!("heartbeat.tick");
|
||||
}
|
||||
ObserverEvent::Error { component, message } => {
|
||||
info!(component = %component, error = %message, "error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn record_metric(&self, metric: &ObserverMetric) {
|
||||
match metric {
|
||||
ObserverMetric::RequestLatency(d) => {
|
||||
let ms = u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(latency_ms = ms, "metric.request_latency");
|
||||
}
|
||||
ObserverMetric::TokensUsed(t) => {
|
||||
info!(tokens = t, "metric.tokens_used");
|
||||
}
|
||||
ObserverMetric::ActiveSessions(s) => {
|
||||
info!(sessions = s, "metric.active_sessions");
|
||||
}
|
||||
ObserverMetric::QueueDepth(d) => {
|
||||
info!(depth = d, "metric.queue_depth");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"log"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn log_observer_name() {
|
||||
assert_eq!(LogObserver::new().name(), "log");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_observer_all_events_no_panic() {
|
||||
let obs = LogObserver::new();
|
||||
obs.record_event(&ObserverEvent::AgentStart {
|
||||
provider: "openrouter".into(),
|
||||
model: "claude-sonnet".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(500),
|
||||
tokens_used: Some(100),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCall {
|
||||
tool: "shell".into(),
|
||||
duration: Duration::from_millis(10),
|
||||
success: false,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ChannelMessage {
|
||||
channel: "telegram".into(),
|
||||
direction: "outbound".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::HeartbeatTick);
|
||||
obs.record_event(&ObserverEvent::Error {
|
||||
component: "provider".into(),
|
||||
message: "timeout".into(),
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn log_observer_all_metrics_no_panic() {
|
||||
let obs = LogObserver::new();
|
||||
obs.record_metric(&ObserverMetric::RequestLatency(Duration::from_secs(2)));
|
||||
obs.record_metric(&ObserverMetric::TokensUsed(0));
|
||||
obs.record_metric(&ObserverMetric::TokensUsed(u64::MAX));
|
||||
obs.record_metric(&ObserverMetric::ActiveSessions(1));
|
||||
obs.record_metric(&ObserverMetric::QueueDepth(999));
|
||||
}
|
||||
}
|
||||
76
src/observability/mod.rs
Normal file
76
src/observability/mod.rs
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
pub mod log;
|
||||
pub mod multi;
|
||||
pub mod noop;
|
||||
pub mod traits;
|
||||
|
||||
pub use self::log::LogObserver;
|
||||
pub use noop::NoopObserver;
|
||||
pub use traits::{Observer, ObserverEvent};
|
||||
|
||||
use crate::config::ObservabilityConfig;
|
||||
|
||||
/// Factory: create the right observer from config
|
||||
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
|
||||
match config.backend.as_str() {
|
||||
"log" => Box::new(LogObserver::new()),
|
||||
"none" | "noop" => Box::new(NoopObserver),
|
||||
_ => {
|
||||
tracing::warn!(
|
||||
"Unknown observability backend '{}', falling back to noop",
|
||||
config.backend
|
||||
);
|
||||
Box::new(NoopObserver)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn factory_none_returns_noop() {
|
||||
let cfg = ObservabilityConfig {
|
||||
backend: "none".into(),
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "noop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_noop_returns_noop() {
|
||||
let cfg = ObservabilityConfig {
|
||||
backend: "noop".into(),
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "noop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_log_returns_log() {
|
||||
let cfg = ObservabilityConfig {
|
||||
backend: "log".into(),
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "log");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_unknown_falls_back_to_noop() {
|
||||
let cfg = ObservabilityConfig {
|
||||
backend: "prometheus".into(),
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "noop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_empty_string_falls_back_to_noop() {
|
||||
let cfg = ObservabilityConfig { backend: "".into() };
|
||||
assert_eq!(create_observer(&cfg).name(), "noop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_garbage_falls_back_to_noop() {
|
||||
let cfg = ObservabilityConfig {
|
||||
backend: "xyzzy_garbage_123".into(),
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "noop");
|
||||
}
|
||||
}
|
||||
154
src/observability/multi.rs
Normal file
154
src/observability/multi.rs
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
use super::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
|
||||
/// Combine multiple observers — fan-out events to all backends
|
||||
pub struct MultiObserver {
|
||||
observers: Vec<Box<dyn Observer>>,
|
||||
}
|
||||
|
||||
impl MultiObserver {
|
||||
pub fn new(observers: Vec<Box<dyn Observer>>) -> Self {
|
||||
Self { observers }
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer for MultiObserver {
|
||||
fn record_event(&self, event: &ObserverEvent) {
|
||||
for obs in &self.observers {
|
||||
obs.record_event(event);
|
||||
}
|
||||
}
|
||||
|
||||
fn record_metric(&self, metric: &ObserverMetric) {
|
||||
for obs in &self.observers {
|
||||
obs.record_metric(metric);
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
for obs in &self.observers {
|
||||
obs.flush();
|
||||
}
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"multi"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Test observer that counts calls
|
||||
struct CountingObserver {
|
||||
event_count: Arc<AtomicUsize>,
|
||||
metric_count: Arc<AtomicUsize>,
|
||||
flush_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl CountingObserver {
|
||||
fn new(
|
||||
event_count: Arc<AtomicUsize>,
|
||||
metric_count: Arc<AtomicUsize>,
|
||||
flush_count: Arc<AtomicUsize>,
|
||||
) -> Self {
|
||||
Self {
|
||||
event_count,
|
||||
metric_count,
|
||||
flush_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Observer for CountingObserver {
|
||||
fn record_event(&self, _event: &ObserverEvent) {
|
||||
self.event_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
fn record_metric(&self, _metric: &ObserverMetric) {
|
||||
self.metric_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
fn flush(&self) {
|
||||
self.flush_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
fn name(&self) -> &str {
|
||||
"counting"
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_name() {
|
||||
let m = MultiObserver::new(vec![]);
|
||||
assert_eq!(m.name(), "multi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_empty_no_panic() {
|
||||
let m = MultiObserver::new(vec![]);
|
||||
m.record_event(&ObserverEvent::HeartbeatTick);
|
||||
m.record_metric(&ObserverMetric::TokensUsed(10));
|
||||
m.flush();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_fans_out_events() {
|
||||
let ec1 = Arc::new(AtomicUsize::new(0));
|
||||
let mc1 = Arc::new(AtomicUsize::new(0));
|
||||
let fc1 = Arc::new(AtomicUsize::new(0));
|
||||
let ec2 = Arc::new(AtomicUsize::new(0));
|
||||
let mc2 = Arc::new(AtomicUsize::new(0));
|
||||
let fc2 = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let m = MultiObserver::new(vec![
|
||||
Box::new(CountingObserver::new(ec1.clone(), mc1.clone(), fc1.clone())),
|
||||
Box::new(CountingObserver::new(ec2.clone(), mc2.clone(), fc2.clone())),
|
||||
]);
|
||||
|
||||
m.record_event(&ObserverEvent::HeartbeatTick);
|
||||
m.record_event(&ObserverEvent::HeartbeatTick);
|
||||
m.record_event(&ObserverEvent::HeartbeatTick);
|
||||
|
||||
assert_eq!(ec1.load(Ordering::SeqCst), 3);
|
||||
assert_eq!(ec2.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_fans_out_metrics() {
|
||||
let ec1 = Arc::new(AtomicUsize::new(0));
|
||||
let mc1 = Arc::new(AtomicUsize::new(0));
|
||||
let fc1 = Arc::new(AtomicUsize::new(0));
|
||||
let ec2 = Arc::new(AtomicUsize::new(0));
|
||||
let mc2 = Arc::new(AtomicUsize::new(0));
|
||||
let fc2 = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let m = MultiObserver::new(vec![
|
||||
Box::new(CountingObserver::new(ec1.clone(), mc1.clone(), fc1.clone())),
|
||||
Box::new(CountingObserver::new(ec2.clone(), mc2.clone(), fc2.clone())),
|
||||
]);
|
||||
|
||||
m.record_metric(&ObserverMetric::TokensUsed(100));
|
||||
m.record_metric(&ObserverMetric::RequestLatency(Duration::from_millis(5)));
|
||||
|
||||
assert_eq!(mc1.load(Ordering::SeqCst), 2);
|
||||
assert_eq!(mc2.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_fans_out_flush() {
|
||||
let ec = Arc::new(AtomicUsize::new(0));
|
||||
let mc = Arc::new(AtomicUsize::new(0));
|
||||
let fc1 = Arc::new(AtomicUsize::new(0));
|
||||
let fc2 = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let m = MultiObserver::new(vec![
|
||||
Box::new(CountingObserver::new(ec.clone(), mc.clone(), fc1.clone())),
|
||||
Box::new(CountingObserver::new(ec.clone(), mc.clone(), fc2.clone())),
|
||||
]);
|
||||
|
||||
m.flush();
|
||||
assert_eq!(fc1.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(fc2.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
}
|
||||
72
src/observability/noop.rs
Normal file
72
src/observability/noop.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use super::traits::{Observer, ObserverEvent, ObserverMetric};
|
||||
|
||||
/// Zero-overhead observer — all methods compile to nothing
|
||||
pub struct NoopObserver;
|
||||
|
||||
impl Observer for NoopObserver {
|
||||
#[inline(always)]
|
||||
fn record_event(&self, _event: &ObserverEvent) {}
|
||||
|
||||
#[inline(always)]
|
||||
fn record_metric(&self, _metric: &ObserverMetric) {}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"noop"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn noop_name() {
|
||||
assert_eq!(NoopObserver.name(), "noop");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noop_record_event_does_not_panic() {
|
||||
let obs = NoopObserver;
|
||||
obs.record_event(&ObserverEvent::HeartbeatTick);
|
||||
obs.record_event(&ObserverEvent::AgentStart {
|
||||
provider: "test".into(),
|
||||
model: "test".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(100),
|
||||
tokens_used: Some(42),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCall {
|
||||
tool: "shell".into(),
|
||||
duration: Duration::from_secs(1),
|
||||
success: true,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ChannelMessage {
|
||||
channel: "cli".into(),
|
||||
direction: "inbound".into(),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::Error {
|
||||
component: "test".into(),
|
||||
message: "boom".into(),
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noop_record_metric_does_not_panic() {
|
||||
let obs = NoopObserver;
|
||||
obs.record_metric(&ObserverMetric::RequestLatency(Duration::from_millis(50)));
|
||||
obs.record_metric(&ObserverMetric::TokensUsed(1000));
|
||||
obs.record_metric(&ObserverMetric::ActiveSessions(5));
|
||||
obs.record_metric(&ObserverMetric::QueueDepth(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn noop_flush_does_not_panic() {
|
||||
NoopObserver.flush();
|
||||
}
|
||||
}
|
||||
52
src/observability/traits.rs
Normal file
52
src/observability/traits.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
use std::time::Duration;
|
||||
|
||||
/// Events the observer can record
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ObserverEvent {
|
||||
AgentStart {
|
||||
provider: String,
|
||||
model: String,
|
||||
},
|
||||
AgentEnd {
|
||||
duration: Duration,
|
||||
tokens_used: Option<u64>,
|
||||
},
|
||||
ToolCall {
|
||||
tool: String,
|
||||
duration: Duration,
|
||||
success: bool,
|
||||
},
|
||||
ChannelMessage {
|
||||
channel: String,
|
||||
direction: String,
|
||||
},
|
||||
HeartbeatTick,
|
||||
Error {
|
||||
component: String,
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Numeric metrics
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ObserverMetric {
|
||||
RequestLatency(Duration),
|
||||
TokensUsed(u64),
|
||||
ActiveSessions(u64),
|
||||
QueueDepth(u64),
|
||||
}
|
||||
|
||||
/// Core observability trait — implement for any backend
|
||||
pub trait Observer: Send + Sync {
|
||||
/// Record a discrete event
|
||||
fn record_event(&self, event: &ObserverEvent);
|
||||
|
||||
/// Record a numeric metric
|
||||
fn record_metric(&self, metric: &ObserverMetric);
|
||||
|
||||
/// Flush any buffered data (no-op for most backends)
|
||||
fn flush(&self) {}
|
||||
|
||||
/// Human-readable name of this observer
|
||||
fn name(&self) -> &str;
|
||||
}
|
||||
3
src/onboard/mod.rs
Normal file
3
src/onboard/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod wizard;
|
||||
|
||||
pub use wizard::run_wizard;
|
||||
1804
src/onboard/wizard.rs
Normal file
1804
src/onboard/wizard.rs
Normal file
File diff suppressed because it is too large
Load diff
212
src/providers/anthropic.rs
Normal file
212
src/providers/anthropic.rs
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct AnthropicProvider {
|
||||
api_key: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
content: Vec<ContentBlock>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ContentBlock {
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for AnthropicProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Anthropic API key not set. Set ANTHROPIC_API_KEY or edit config.toml."
|
||||
)
|
||||
})?;
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
max_tokens: 4096,
|
||||
system: system_prompt.map(ToString::to_string),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
}],
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
.header("x-api-key", api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("Anthropic API error: {error}");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.content
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.text)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
||||
assert!(p.api_key.is_some());
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-ant-test123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = AnthropicProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = AnthropicProvider::new(Some(""));
|
||||
assert!(p.api_key.is_some());
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_key() {
|
||||
let p = AnthropicProvider::new(None);
|
||||
let result = p.chat_with_system(None, "hello", "claude-3-opus", 0.7).await;
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(err.contains("API key not set"), "Expected key error, got: {err}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_fails_without_key() {
|
||||
let p = AnthropicProvider::new(None);
|
||||
let result = p
|
||||
.chat_with_system(Some("You are ZeroClaw"), "hello", "claude-3-opus", 0.7)
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "claude-3-opus".to_string(),
|
||||
max_tokens: 4096,
|
||||
system: None,
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
}],
|
||||
temperature: 0.7,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("system"), "system field should be skipped when None");
|
||||
assert!(json.contains("claude-3-opus"));
|
||||
assert!(json.contains("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "claude-3-opus".to_string(),
|
||||
max_tokens: 4096,
|
||||
system: Some("You are ZeroClaw".to_string()),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
}],
|
||||
temperature: 0.7,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"system\":\"You are ZeroClaw\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_response_deserializes() {
|
||||
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.content.len(), 1);
|
||||
assert_eq!(resp.content[0].text, "Hello there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_response_empty_content() {
|
||||
let json = r#"{"content":[]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_response_multiple_blocks() {
|
||||
let json = r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.content.len(), 2);
|
||||
assert_eq!(resp.content[0].text, "First");
|
||||
assert_eq!(resp.content[1].text, "Second");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn temperature_range_serializes() {
|
||||
for temp in [0.0, 0.5, 1.0, 2.0] {
|
||||
let req = ChatRequest {
|
||||
model: "claude-3-opus".to_string(),
|
||||
max_tokens: 4096,
|
||||
system: None,
|
||||
messages: vec![],
|
||||
temperature: temp,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains(&format!("{temp}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
245
src/providers/compatible.rs
Normal file
245
src/providers/compatible.rs
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
//! Generic OpenAI-compatible provider.
|
||||
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
||||
//! This module provides a single implementation that works for all of them.
|
||||
|
||||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A provider that speaks the OpenAI-compatible chat completions API.
|
||||
/// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
|
||||
/// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
|
||||
pub struct OpenAiCompatibleProvider {
|
||||
pub(crate) name: String,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: Option<String>,
|
||||
pub(crate) auth_header: AuthStyle,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
/// How the provider expects the API key to be sent.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthStyle {
|
||||
/// `Authorization: Bearer <key>`
|
||||
Bearer,
|
||||
/// `x-api-key: <key>` (used by some Chinese providers)
|
||||
XApiKey,
|
||||
/// Custom header name
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenAiCompatibleProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let url = format!("{}/v1/chat/completions", self.base_url);
|
||||
|
||||
let mut req = self.client.post(&url).json(&request);
|
||||
|
||||
match &self.auth_header {
|
||||
AuthStyle::Bearer => {
|
||||
req = req.header("Authorization", format!("Bearer {api_key}"));
|
||||
}
|
||||
AuthStyle::XApiKey => {
|
||||
req = req.header("x-api-key", api_key.as_str());
|
||||
}
|
||||
AuthStyle::Custom(header) => {
|
||||
req = req.header(header.as_str(), api_key.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
let response = req.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("{} API error: {error}", self.name);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_provider(name: &str, url: &str, key: Option<&str>) -> OpenAiCompatibleProvider {
|
||||
OpenAiCompatibleProvider::new(name, url, key, AuthStyle::Bearer)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
||||
assert_eq!(p.name, "venice");
|
||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = make_provider("test", "https://example.com", None);
|
||||
assert!(p.api_key.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strips_trailing_slash() {
|
||||
let p = make_provider("test", "https://example.com/", None);
|
||||
assert_eq!(p.base_url, "https://example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_key() {
|
||||
let p = make_provider("Venice", "https://api.venice.ai", None);
|
||||
let result = p.chat_with_system(None, "hello", "llama-3.3-70b", 0.7).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("Venice API key not set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_correctly() {
|
||||
let req = ChatRequest {
|
||||
model: "llama-3.3-70b".to_string(),
|
||||
messages: vec![
|
||||
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
|
||||
Message { role: "user".to_string(), content: "hello".to_string() },
|
||||
],
|
||||
temperature: 0.7,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("llama-3.3-70b"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("user"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices[0].message.content, "Hello from Venice!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_empty_choices() {
|
||||
let json = r#"{"choices":[]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.choices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn x_api_key_auth_style() {
|
||||
let p = OpenAiCompatibleProvider::new(
|
||||
"moonshot", "https://api.moonshot.cn", Some("ms-key"), AuthStyle::XApiKey,
|
||||
);
|
||||
assert!(matches!(p.auth_header, AuthStyle::XApiKey));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_auth_style() {
|
||||
let p = OpenAiCompatibleProvider::new(
|
||||
"custom", "https://api.example.com", Some("key"), AuthStyle::Custom("X-Custom-Key".into()),
|
||||
);
|
||||
assert!(matches!(p.auth_header, AuthStyle::Custom(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn all_compatible_providers_fail_without_key() {
|
||||
let providers = vec![
|
||||
make_provider("Venice", "https://api.venice.ai", None),
|
||||
make_provider("Moonshot", "https://api.moonshot.cn", None),
|
||||
make_provider("GLM", "https://open.bigmodel.cn", None),
|
||||
make_provider("MiniMax", "https://api.minimax.chat", None),
|
||||
make_provider("Groq", "https://api.groq.com/openai", None),
|
||||
make_provider("Mistral", "https://api.mistral.ai", None),
|
||||
make_provider("xAI", "https://api.x.ai", None),
|
||||
];
|
||||
|
||||
for p in providers {
|
||||
let result = p.chat_with_system(None, "test", "model", 0.7).await;
|
||||
assert!(result.is_err(), "{} should fail without key", p.name);
|
||||
assert!(
|
||||
result.unwrap_err().to_string().contains("API key not set"),
|
||||
"{} error should mention key", p.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
266
src/providers/mod.rs
Normal file
266
src/providers/mod.rs
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
pub mod anthropic;
|
||||
pub mod compatible;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openrouter;
|
||||
pub mod traits;
|
||||
|
||||
pub use traits::Provider;
|
||||
|
||||
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||
|
||||
/// Factory: create the right provider from config
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||
match name {
|
||||
// ── Primary providers (custom implementations) ───────
|
||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))),
|
||||
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))),
|
||||
"openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))),
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(
|
||||
api_key.filter(|k| !k.is_empty()),
|
||||
))),
|
||||
|
||||
// ── OpenAI-compatible providers ──────────────────────
|
||||
"venice" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Cloudflare AI Gateway",
|
||||
"https://gateway.ai.cloudflare.com/v1",
|
||||
api_key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
"moonshot" | "kimi" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Z.AI", "https://api.z.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"GLM", "https://open.bigmodel.cn/api/paas", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"minimax" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Amazon Bedrock",
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api_key,
|
||||
AuthStyle::Bearer,
|
||||
))),
|
||||
"qianfan" | "baidu" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
// ── Extended ecosystem (community favorites) ─────────
|
||||
"groq" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"xAI", "https://api.x.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
_ => anyhow::bail!(
|
||||
"Unknown provider: {name}. Run `zeroclaw integrations list -c ai` to see all available providers."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Primary providers ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_openrouter() {
|
||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openrouter", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_anthropic() {
|
||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openai() {
|
||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_ollama() {
|
||||
assert!(create_provider("ollama", None).is_ok());
|
||||
}
|
||||
|
||||
// ── OpenAI-compatible providers ──────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_venice() {
|
||||
assert!(create_provider("venice", Some("vn-key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_vercel() {
|
||||
assert!(create_provider("vercel", Some("key")).is_ok());
|
||||
assert!(create_provider("vercel-ai", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_cloudflare() {
|
||||
assert!(create_provider("cloudflare", Some("key")).is_ok());
|
||||
assert!(create_provider("cloudflare-ai", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_moonshot() {
|
||||
assert!(create_provider("moonshot", Some("key")).is_ok());
|
||||
assert!(create_provider("kimi", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_synthetic() {
|
||||
assert!(create_provider("synthetic", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_opencode() {
|
||||
assert!(create_provider("opencode", Some("key")).is_ok());
|
||||
assert!(create_provider("opencode-zen", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_zai() {
|
||||
assert!(create_provider("zai", Some("key")).is_ok());
|
||||
assert!(create_provider("z.ai", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_glm() {
|
||||
assert!(create_provider("glm", Some("key")).is_ok());
|
||||
assert!(create_provider("zhipu", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_minimax() {
|
||||
assert!(create_provider("minimax", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_bedrock() {
|
||||
assert!(create_provider("bedrock", Some("key")).is_ok());
|
||||
assert!(create_provider("aws-bedrock", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_qianfan() {
|
||||
assert!(create_provider("qianfan", Some("key")).is_ok());
|
||||
assert!(create_provider("baidu", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
// ── Extended ecosystem ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_groq() {
|
||||
assert!(create_provider("groq", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_mistral() {
|
||||
assert!(create_provider("mistral", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_xai() {
|
||||
assert!(create_provider("xai", Some("key")).is_ok());
|
||||
assert!(create_provider("grok", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_deepseek() {
|
||||
assert!(create_provider("deepseek", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_together() {
|
||||
assert!(create_provider("together", Some("key")).is_ok());
|
||||
assert!(create_provider("together-ai", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_fireworks() {
|
||||
assert!(create_provider("fireworks", Some("key")).is_ok());
|
||||
assert!(create_provider("fireworks-ai", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_perplexity() {
|
||||
assert!(create_provider("perplexity", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_cohere() {
|
||||
assert!(create_provider("cohere", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
// ── Error cases ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_unknown_provider_errors() {
|
||||
let p = create_provider("nonexistent", None);
|
||||
assert!(p.is_err());
|
||||
let msg = p.err().unwrap().to_string();
|
||||
assert!(msg.contains("Unknown provider"));
|
||||
assert!(msg.contains("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_empty_name_errors() {
|
||||
assert!(create_provider("", None).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_all_providers_create_successfully() {
|
||||
let providers = [
|
||||
"openrouter", "anthropic", "openai", "ollama",
|
||||
"venice", "vercel", "cloudflare", "moonshot", "synthetic",
|
||||
"opencode", "zai", "glm", "minimax", "bedrock", "qianfan",
|
||||
"groq", "mistral", "xai", "deepseek", "together",
|
||||
"fireworks", "perplexity", "cohere",
|
||||
];
|
||||
for name in providers {
|
||||
assert!(
|
||||
create_provider(name, Some("test-key")).is_ok(),
|
||||
"Provider '{name}' should create successfully"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
177
src/providers/ollama.rs
Normal file
177
src/providers/ollama.rs
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OllamaProvider {
|
||||
base_url: String,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
options: Options,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Options {
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
base_url: base_url
|
||||
.unwrap_or("http://localhost:11434")
|
||||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OllamaProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("Ollama error: {error}. Is Ollama running? (brew install ollama && ollama serve)");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
Ok(chat_response.message.content)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_url() {
|
||||
let p = OllamaProvider::new(None);
|
||||
assert_eq!(p.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"));
|
||||
assert_eq!(p.base_url, "http://192.168.1.100:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_no_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://myserver:11434"));
|
||||
assert_eq!(p.base_url, "http://myserver:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_url_uses_empty() {
|
||||
let p = OllamaProvider::new(Some(""));
|
||||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
|
||||
Message { role: "user".to_string(), content: "hello".to_string() },
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![
|
||||
Message { role: "user".to_string(), content: "test".to_string() },
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "Hello from Ollama!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_empty_content() {
|
||||
let json = r#"{"message":{"role":"assistant","content":""}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
}
|
||||
}
|
||||
211
src/providers/openai.rs
Normal file
211
src/providers/openai.rs
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenAiProvider {
|
||||
api_key: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenAiProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("OpenAI API error: {error}");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = OpenAiProvider::new(Some(""));
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_fails_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
let result = p
|
||||
.chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system_message() {
|
||||
let req = ChatRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
messages: vec![
|
||||
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
|
||||
Message { role: "user".to_string(), content: "hello".to_string() },
|
||||
],
|
||||
temperature: 0.7,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
assert!(json.contains("gpt-4o"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
messages: vec![
|
||||
Message { role: "user".to_string(), content: "hello".to_string() },
|
||||
],
|
||||
temperature: 0.0,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_single_choice() {
|
||||
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices.len(), 1);
|
||||
assert_eq!(resp.choices[0].message.content, "Hi!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_empty_choices() {
|
||||
let json = r#"{"choices":[]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.choices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_multiple_choices() {
|
||||
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices.len(), 2);
|
||||
assert_eq!(resp.choices[0].message.content, "A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_unicode() {
|
||||
let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices[0].message.content, "こんにちは 🦀");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_long_content() {
|
||||
let long = "x".repeat(100_000);
|
||||
let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#);
|
||||
let resp: ChatResponse = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(resp.choices[0].message.content.len(), 100_000);
|
||||
}
|
||||
}
|
||||
107
src/providers/openrouter.rs
Normal file
107
src/providers/openrouter.rs
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
api_key: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenRouterProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
)
|
||||
.header("X-Title", "ZeroClaw")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("OpenRouter API error: {error}");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
}
|
||||
22
src/providers/traits.rs
Normal file
22
src/providers/traits.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
async fn chat(
|
||||
&self,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.chat_with_system(None, message, model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String>;
|
||||
}
|
||||
71
src/runtime/mod.rs
Normal file
71
src/runtime/mod.rs
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
pub mod native;
|
||||
pub mod traits;
|
||||
|
||||
pub use native::NativeRuntime;
|
||||
pub use traits::RuntimeAdapter;
|
||||
|
||||
use crate::config::RuntimeConfig;
|
||||
|
||||
/// Factory: create the right runtime from config
|
||||
pub fn create_runtime(config: &RuntimeConfig) -> Box<dyn RuntimeAdapter> {
|
||||
match config.kind.as_str() {
|
||||
"native" | "docker" => Box::new(NativeRuntime::new()),
|
||||
"cloudflare" => {
|
||||
tracing::warn!("Cloudflare runtime not yet implemented, falling back to native");
|
||||
Box::new(NativeRuntime::new())
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!("Unknown runtime '{}', falling back to native", config.kind);
|
||||
Box::new(NativeRuntime::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn factory_native() {
|
||||
let cfg = RuntimeConfig {
|
||||
kind: "native".into(),
|
||||
};
|
||||
let rt = create_runtime(&cfg);
|
||||
assert_eq!(rt.name(), "native");
|
||||
assert!(rt.has_shell_access());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_docker_returns_native() {
|
||||
let cfg = RuntimeConfig {
|
||||
kind: "docker".into(),
|
||||
};
|
||||
let rt = create_runtime(&cfg);
|
||||
assert_eq!(rt.name(), "native");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_cloudflare_falls_back() {
|
||||
let cfg = RuntimeConfig {
|
||||
kind: "cloudflare".into(),
|
||||
};
|
||||
let rt = create_runtime(&cfg);
|
||||
assert_eq!(rt.name(), "native");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_unknown_falls_back() {
|
||||
let cfg = RuntimeConfig {
|
||||
kind: "wasm-edge-unknown".into(),
|
||||
};
|
||||
let rt = create_runtime(&cfg);
|
||||
assert_eq!(rt.name(), "native");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_empty_falls_back() {
|
||||
let cfg = RuntimeConfig { kind: "".into() };
|
||||
let rt = create_runtime(&cfg);
|
||||
assert_eq!(rt.name(), "native");
|
||||
}
|
||||
}
|
||||
72
src/runtime/native.rs
Normal file
72
src/runtime/native.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use super::traits::RuntimeAdapter;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi
|
||||
pub struct NativeRuntime;
|
||||
|
||||
impl NativeRuntime {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl RuntimeAdapter for NativeRuntime {
|
||||
fn name(&self) -> &str {
|
||||
"native"
|
||||
}
|
||||
|
||||
fn has_shell_access(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn has_filesystem_access(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn storage_path(&self) -> PathBuf {
|
||||
directories::UserDirs::new().map_or_else(
|
||||
|| PathBuf::from(".zeroclaw"),
|
||||
|u| u.home_dir().join(".zeroclaw"),
|
||||
)
|
||||
}
|
||||
|
||||
fn supports_long_running(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn native_name() {
|
||||
assert_eq!(NativeRuntime::new().name(), "native");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_has_shell_access() {
|
||||
assert!(NativeRuntime::new().has_shell_access());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_has_filesystem_access() {
|
||||
assert!(NativeRuntime::new().has_filesystem_access());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_supports_long_running() {
|
||||
assert!(NativeRuntime::new().supports_long_running());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_memory_budget_unlimited() {
|
||||
assert_eq!(NativeRuntime::new().memory_budget(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_storage_path_contains_zeroclaw() {
|
||||
let path = NativeRuntime::new().storage_path();
|
||||
assert!(path.to_string_lossy().contains("zeroclaw"));
|
||||
}
|
||||
}
|
||||
25
src/runtime/traits.rs
Normal file
25
src/runtime/traits.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
/// Runtime adapter — abstracts platform differences so the same agent
|
||||
/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc.
|
||||
pub trait RuntimeAdapter: Send + Sync {
|
||||
/// Human-readable runtime name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Whether this runtime supports shell access
|
||||
fn has_shell_access(&self) -> bool;
|
||||
|
||||
/// Whether this runtime supports filesystem access
|
||||
fn has_filesystem_access(&self) -> bool;
|
||||
|
||||
/// Base storage path for this runtime
|
||||
fn storage_path(&self) -> PathBuf;
|
||||
|
||||
/// Whether long-running processes (gateway, heartbeat) are supported
|
||||
fn supports_long_running(&self) -> bool;
|
||||
|
||||
/// Maximum memory budget in bytes (0 = unlimited)
|
||||
fn memory_budget(&self) -> u64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
3
src/security/mod.rs
Normal file
3
src/security/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod policy;
|
||||
|
||||
pub use policy::{AutonomyLevel, SecurityPolicy};
|
||||
365
src/security/policy.rs
Normal file
365
src/security/policy.rs
Normal file
|
|
@ -0,0 +1,365 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// How much autonomy the agent has
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AutonomyLevel {
|
||||
/// Read-only: can observe but not act
|
||||
ReadOnly,
|
||||
/// Supervised: acts but requires approval for risky operations
|
||||
Supervised,
|
||||
/// Full: autonomous execution within policy bounds
|
||||
Full,
|
||||
}
|
||||
|
||||
impl Default for AutonomyLevel {
|
||||
fn default() -> Self {
|
||||
Self::Supervised
|
||||
}
|
||||
}
|
||||
|
||||
/// Security policy enforced on all tool executions
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SecurityPolicy {
|
||||
pub autonomy: AutonomyLevel,
|
||||
pub workspace_dir: PathBuf,
|
||||
pub workspace_only: bool,
|
||||
pub allowed_commands: Vec<String>,
|
||||
pub forbidden_paths: Vec<String>,
|
||||
pub max_actions_per_hour: u32,
|
||||
pub max_cost_per_day_cents: u32,
|
||||
}
|
||||
|
||||
impl Default for SecurityPolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
workspace_dir: PathBuf::from("."),
|
||||
workspace_only: true,
|
||||
allowed_commands: vec![
|
||||
"git".into(),
|
||||
"npm".into(),
|
||||
"cargo".into(),
|
||||
"ls".into(),
|
||||
"cat".into(),
|
||||
"grep".into(),
|
||||
"find".into(),
|
||||
"echo".into(),
|
||||
"pwd".into(),
|
||||
"wc".into(),
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
],
|
||||
forbidden_paths: vec![
|
||||
"/etc".into(),
|
||||
"/root".into(),
|
||||
"~/.ssh".into(),
|
||||
"~/.gnupg".into(),
|
||||
"/var/run".into(),
|
||||
],
|
||||
max_actions_per_hour: 20,
|
||||
max_cost_per_day_cents: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SecurityPolicy {
|
||||
/// Check if a shell command is allowed
|
||||
pub fn is_command_allowed(&self, command: &str) -> bool {
|
||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Extract the base command (first word)
|
||||
let base_cmd = command
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
|
||||
self.allowed_commands
|
||||
.iter()
|
||||
.any(|allowed| allowed == base_cmd)
|
||||
}
|
||||
|
||||
/// Check if a file path is allowed (no path traversal, within workspace)
|
||||
pub fn is_path_allowed(&self, path: &str) -> bool {
|
||||
// Block obvious traversal attempts
|
||||
if path.contains("..") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block absolute paths when workspace_only is set
|
||||
if self.workspace_only && Path::new(path).is_absolute() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block forbidden paths
|
||||
for forbidden in &self.forbidden_paths {
|
||||
if path.starts_with(forbidden.as_str()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Check if autonomy level permits any action at all
|
||||
pub fn can_act(&self) -> bool {
|
||||
self.autonomy != AutonomyLevel::ReadOnly
|
||||
}
|
||||
|
||||
/// Build from config sections
|
||||
pub fn from_config(
|
||||
autonomy_config: &crate::config::AutonomyConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> Self {
|
||||
Self {
|
||||
autonomy: autonomy_config.level,
|
||||
workspace_dir: workspace_dir.to_path_buf(),
|
||||
workspace_only: autonomy_config.workspace_only,
|
||||
allowed_commands: autonomy_config.allowed_commands.clone(),
|
||||
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
||||
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
||||
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn default_policy() -> SecurityPolicy {
|
||||
SecurityPolicy::default()
|
||||
}
|
||||
|
||||
fn readonly_policy() -> SecurityPolicy {
|
||||
SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn full_policy() -> SecurityPolicy {
|
||||
SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
..SecurityPolicy::default()
|
||||
}
|
||||
}
|
||||
|
||||
// ── AutonomyLevel ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn autonomy_default_is_supervised() {
|
||||
assert_eq!(AutonomyLevel::default(), AutonomyLevel::Supervised);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn autonomy_serde_roundtrip() {
|
||||
let json = serde_json::to_string(&AutonomyLevel::Full).unwrap();
|
||||
assert_eq!(json, "\"full\"");
|
||||
let parsed: AutonomyLevel = serde_json::from_str("\"readonly\"").unwrap();
|
||||
assert_eq!(parsed, AutonomyLevel::ReadOnly);
|
||||
let parsed2: AutonomyLevel = serde_json::from_str("\"supervised\"").unwrap();
|
||||
assert_eq!(parsed2, AutonomyLevel::Supervised);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_act_readonly_false() {
|
||||
assert!(!readonly_policy().can_act());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_act_supervised_true() {
|
||||
assert!(default_policy().can_act());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_act_full_true() {
|
||||
assert!(full_policy().can_act());
|
||||
}
|
||||
|
||||
// ── is_command_allowed ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn allowed_commands_basic() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_command_allowed("ls"));
|
||||
assert!(p.is_command_allowed("git status"));
|
||||
assert!(p.is_command_allowed("cargo build --release"));
|
||||
assert!(p.is_command_allowed("cat file.txt"));
|
||||
assert!(p.is_command_allowed("grep -r pattern ."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocked_commands_basic() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("rm -rf /"));
|
||||
assert!(!p.is_command_allowed("sudo apt install"));
|
||||
assert!(!p.is_command_allowed("curl http://evil.com"));
|
||||
assert!(!p.is_command_allowed("wget http://evil.com"));
|
||||
assert!(!p.is_command_allowed("python3 exploit.py"));
|
||||
assert!(!p.is_command_allowed("node malicious.js"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn readonly_blocks_all_commands() {
|
||||
let p = readonly_policy();
|
||||
assert!(!p.is_command_allowed("ls"));
|
||||
assert!(!p.is_command_allowed("cat file.txt"));
|
||||
assert!(!p.is_command_allowed("echo hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_autonomy_still_uses_allowlist() {
|
||||
let p = full_policy();
|
||||
assert!(p.is_command_allowed("ls"));
|
||||
assert!(!p.is_command_allowed("rm -rf /"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_with_absolute_path_extracts_basename() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_command_allowed("/usr/bin/git status"));
|
||||
assert!(p.is_command_allowed("/bin/ls -la"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_command_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed(""));
|
||||
assert!(!p.is_command_allowed(" "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_with_pipes_uses_first_word() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_command_allowed("ls | grep foo"));
|
||||
assert!(p.is_command_allowed("cat file.txt | wc -l"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_allowlist() {
|
||||
let p = SecurityPolicy {
|
||||
allowed_commands: vec!["docker".into(), "kubectl".into()],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(p.is_command_allowed("docker ps"));
|
||||
assert!(p.is_command_allowed("kubectl get pods"));
|
||||
assert!(!p.is_command_allowed("ls"));
|
||||
assert!(!p.is_command_allowed("git status"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_blocks_everything() {
|
||||
let p = SecurityPolicy {
|
||||
allowed_commands: vec![],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(!p.is_command_allowed("ls"));
|
||||
assert!(!p.is_command_allowed("echo hello"));
|
||||
}
|
||||
|
||||
// ── is_path_allowed ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn relative_paths_allowed() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_path_allowed("file.txt"));
|
||||
assert!(p.is_path_allowed("src/main.rs"));
|
||||
assert!(p.is_path_allowed("deep/nested/dir/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn path_traversal_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_path_allowed("../etc/passwd"));
|
||||
assert!(!p.is_path_allowed("../../root/.ssh/id_rsa"));
|
||||
assert!(!p.is_path_allowed("foo/../../../etc/shadow"));
|
||||
assert!(!p.is_path_allowed(".."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_paths_blocked_when_workspace_only() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_path_allowed("/etc/passwd"));
|
||||
assert!(!p.is_path_allowed("/root/.ssh/id_rsa"));
|
||||
assert!(!p.is_path_allowed("/tmp/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn absolute_paths_allowed_when_not_workspace_only() {
|
||||
let p = SecurityPolicy {
|
||||
workspace_only: false,
|
||||
forbidden_paths: vec![],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(p.is_path_allowed("/tmp/file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forbidden_paths_blocked() {
|
||||
let p = SecurityPolicy {
|
||||
workspace_only: false,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(!p.is_path_allowed("/etc/passwd"));
|
||||
assert!(!p.is_path_allowed("/root/.bashrc"));
|
||||
assert!(!p.is_path_allowed("~/.ssh/id_rsa"));
|
||||
assert!(!p.is_path_allowed("~/.gnupg/pubring.kbx"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_path_allowed() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_path_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dotfile_in_workspace_allowed() {
|
||||
let p = default_policy();
|
||||
assert!(p.is_path_allowed(".gitignore"));
|
||||
assert!(p.is_path_allowed(".env"));
|
||||
}
|
||||
|
||||
// ── from_config ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn from_config_maps_all_fields() {
|
||||
let autonomy_config = crate::config::AutonomyConfig {
|
||||
level: AutonomyLevel::Full,
|
||||
workspace_only: false,
|
||||
allowed_commands: vec!["docker".into()],
|
||||
forbidden_paths: vec!["/secret".into()],
|
||||
max_actions_per_hour: 100,
|
||||
max_cost_per_day_cents: 1000,
|
||||
};
|
||||
let workspace = PathBuf::from("/tmp/test-workspace");
|
||||
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
|
||||
|
||||
assert_eq!(policy.autonomy, AutonomyLevel::Full);
|
||||
assert!(!policy.workspace_only);
|
||||
assert_eq!(policy.allowed_commands, vec!["docker"]);
|
||||
assert_eq!(policy.forbidden_paths, vec!["/secret"]);
|
||||
assert_eq!(policy.max_actions_per_hour, 100);
|
||||
assert_eq!(policy.max_cost_per_day_cents, 1000);
|
||||
assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace"));
|
||||
}
|
||||
|
||||
// ── Default policy ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn default_policy_has_sane_values() {
|
||||
let p = SecurityPolicy::default();
|
||||
assert_eq!(p.autonomy, AutonomyLevel::Supervised);
|
||||
assert!(p.workspace_only);
|
||||
assert!(!p.allowed_commands.is_empty());
|
||||
assert!(!p.forbidden_paths.is_empty());
|
||||
assert!(p.max_actions_per_hour > 0);
|
||||
assert!(p.max_cost_per_day_cents > 0);
|
||||
}
|
||||
}
|
||||
615
src/skills/mod.rs
Normal file
615
src/skills/mod.rs
Normal file
|
|
@ -0,0 +1,615 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// A skill is a user-defined or community-built capability.
|
||||
/// Skills live in `~/.zeroclaw/workspace/skills/<name>/SKILL.md`
|
||||
/// and can include tool definitions, prompts, and automation scripts.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Skill {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub version: String,
|
||||
#[serde(default)]
|
||||
pub author: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tags: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub tools: Vec<SkillTool>,
|
||||
#[serde(default)]
|
||||
pub prompts: Vec<String>,
|
||||
}
|
||||
|
||||
/// A tool defined by a skill (shell command, HTTP call, etc.)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
/// "shell", "http", "script"
|
||||
pub kind: String,
|
||||
/// The command/URL/script to execute
|
||||
pub command: String,
|
||||
#[serde(default)]
|
||||
pub args: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Skill manifest parsed from SKILL.toml
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SkillManifest {
|
||||
skill: SkillMeta,
|
||||
#[serde(default)]
|
||||
tools: Vec<SkillTool>,
|
||||
#[serde(default)]
|
||||
prompts: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SkillMeta {
|
||||
name: String,
|
||||
description: String,
|
||||
#[serde(default = "default_version")]
|
||||
version: String,
|
||||
#[serde(default)]
|
||||
author: Option<String>,
|
||||
#[serde(default)]
|
||||
tags: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_version() -> String {
|
||||
"0.1.0".to_string()
|
||||
}
|
||||
|
||||
/// Load all skills from the workspace skills directory
|
||||
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
|
||||
let skills_dir = workspace_dir.join("skills");
|
||||
if !skills_dir.exists() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut skills = Vec::new();
|
||||
|
||||
let Ok(entries) = std::fs::read_dir(&skills_dir) else {
|
||||
return skills;
|
||||
};
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try SKILL.toml first, then SKILL.md
|
||||
let manifest_path = path.join("SKILL.toml");
|
||||
let md_path = path.join("SKILL.md");
|
||||
|
||||
if manifest_path.exists() {
|
||||
if let Ok(skill) = load_skill_toml(&manifest_path) {
|
||||
skills.push(skill);
|
||||
}
|
||||
} else if md_path.exists() {
|
||||
if let Ok(skill) = load_skill_md(&md_path, &path) {
|
||||
skills.push(skill);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
skills
|
||||
}
|
||||
|
||||
/// Load a skill from a SKILL.toml manifest
|
||||
fn load_skill_toml(path: &Path) -> Result<Skill> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let manifest: SkillManifest = toml::from_str(&content)?;
|
||||
|
||||
Ok(Skill {
|
||||
name: manifest.skill.name,
|
||||
description: manifest.skill.description,
|
||||
version: manifest.skill.version,
|
||||
author: manifest.skill.author,
|
||||
tags: manifest.skill.tags,
|
||||
tools: manifest.tools,
|
||||
prompts: manifest.prompts,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a skill from a SKILL.md file (simpler format)
|
||||
fn load_skill_md(path: &Path, dir: &Path) -> Result<Skill> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let name = dir
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Extract description from first non-heading line
|
||||
let description = content
|
||||
.lines()
|
||||
.find(|l| !l.starts_with('#') && !l.trim().is_empty())
|
||||
.unwrap_or("No description")
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
Ok(Skill {
|
||||
name,
|
||||
description,
|
||||
version: "0.1.0".to_string(),
|
||||
author: None,
|
||||
tags: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
prompts: vec![content],
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a system prompt addition from all loaded skills
|
||||
pub fn skills_to_prompt(skills: &[Skill]) -> String {
|
||||
use std::fmt::Write;
|
||||
|
||||
if skills.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut prompt = String::from("\n## Active Skills\n\n");
|
||||
|
||||
for skill in skills {
|
||||
let _ = writeln!(prompt, "### {} (v{})", skill.name, skill.version);
|
||||
let _ = writeln!(prompt, "{}", skill.description);
|
||||
|
||||
if !skill.tools.is_empty() {
|
||||
prompt.push_str("Tools:\n");
|
||||
for tool in &skill.tools {
|
||||
let _ = writeln!(prompt, "- **{}**: {} ({})", tool.name, tool.description, tool.kind);
|
||||
}
|
||||
}
|
||||
|
||||
for p in &skill.prompts {
|
||||
prompt.push_str(p);
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Get the skills directory path
|
||||
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
|
||||
workspace_dir.join("skills")
|
||||
}
|
||||
|
||||
/// Initialize the skills directory with a README
|
||||
pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> {
|
||||
let dir = skills_dir(workspace_dir);
|
||||
std::fs::create_dir_all(&dir)?;
|
||||
|
||||
let readme = dir.join("README.md");
|
||||
if !readme.exists() {
|
||||
std::fs::write(
|
||||
&readme,
|
||||
"# ZeroClaw Skills\n\n\
|
||||
Each subdirectory is a skill. Create a `SKILL.toml` or `SKILL.md` file inside.\n\n\
|
||||
## SKILL.toml format\n\n\
|
||||
```toml\n\
|
||||
[skill]\n\
|
||||
name = \"my-skill\"\n\
|
||||
description = \"What this skill does\"\n\
|
||||
version = \"0.1.0\"\n\
|
||||
author = \"your-name\"\n\
|
||||
tags = [\"productivity\", \"automation\"]\n\n\
|
||||
[[tools]]\n\
|
||||
name = \"my_tool\"\n\
|
||||
description = \"What this tool does\"\n\
|
||||
kind = \"shell\"\n\
|
||||
command = \"echo hello\"\n\
|
||||
```\n\n\
|
||||
## SKILL.md format (simpler)\n\n\
|
||||
Just write a markdown file with instructions for the agent.\n\
|
||||
The agent will read it and follow the instructions.\n\n\
|
||||
## Installing community skills\n\n\
|
||||
```bash\n\
|
||||
zeroclaw skills install <github-url>\n\
|
||||
zeroclaw skills list\n\
|
||||
```\n",
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle the `skills` CLI command
|
||||
pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> {
|
||||
match command {
|
||||
super::SkillCommands::List => {
|
||||
let skills = load_skills(workspace_dir);
|
||||
if skills.is_empty() {
|
||||
println!("No skills installed.");
|
||||
println!();
|
||||
println!(" Create one: mkdir -p ~/.zeroclaw/workspace/skills/my-skill");
|
||||
println!(" echo '# My Skill' > ~/.zeroclaw/workspace/skills/my-skill/SKILL.md");
|
||||
println!();
|
||||
println!(" Or install: zeroclaw skills install <github-url>");
|
||||
} else {
|
||||
println!("Installed skills ({}):", skills.len());
|
||||
println!();
|
||||
for skill in &skills {
|
||||
println!(
|
||||
" {} {} — {}",
|
||||
console::style(&skill.name).white().bold(),
|
||||
console::style(format!("v{}", skill.version)).dim(),
|
||||
skill.description
|
||||
);
|
||||
if !skill.tools.is_empty() {
|
||||
println!(
|
||||
" Tools: {}",
|
||||
skill.tools.iter().map(|t| t.name.as_str()).collect::<Vec<_>>().join(", ")
|
||||
);
|
||||
}
|
||||
if !skill.tags.is_empty() {
|
||||
println!(
|
||||
" Tags: {}",
|
||||
skill.tags.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
Ok(())
|
||||
}
|
||||
super::SkillCommands::Install { source } => {
|
||||
println!("Installing skill from: {source}");
|
||||
|
||||
let skills_path = skills_dir(workspace_dir);
|
||||
std::fs::create_dir_all(&skills_path)?;
|
||||
|
||||
if source.starts_with("http") || source.contains("github.com") {
|
||||
// Git clone
|
||||
let output = std::process::Command::new("git")
|
||||
.args(["clone", "--depth", "1", &source])
|
||||
.current_dir(&skills_path)
|
||||
.output()?;
|
||||
|
||||
if output.status.success() {
|
||||
println!(" {} Skill installed successfully!", console::style("✓").green().bold());
|
||||
println!(" Restart `zeroclaw channel start` to activate.");
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Git clone failed: {stderr}");
|
||||
}
|
||||
} else {
|
||||
// Local path — symlink or copy
|
||||
let src = PathBuf::from(&source);
|
||||
if !src.exists() {
|
||||
anyhow::bail!("Source path does not exist: {source}");
|
||||
}
|
||||
let name = src.file_name().unwrap_or_default();
|
||||
let dest = skills_path.join(name);
|
||||
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&src, &dest)?;
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
// On non-unix, copy the directory
|
||||
anyhow::bail!("Symlink not supported on this platform. Copy the skill directory manually.");
|
||||
}
|
||||
|
||||
println!(" {} Skill linked: {}", console::style("✓").green().bold(), dest.display());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
super::SkillCommands::Remove { name } => {
|
||||
let skill_path = skills_dir(workspace_dir).join(&name);
|
||||
if !skill_path.exists() {
|
||||
anyhow::bail!("Skill not found: {name}");
|
||||
}
|
||||
|
||||
std::fs::remove_dir_all(&skill_path)?;
|
||||
println!(" {} Skill '{}' removed.", console::style("✓").green().bold(), name);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
|
||||
#[test]
|
||||
fn load_empty_skills_dir() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills = load_skills(dir.path());
|
||||
assert!(skills.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_skill_from_toml() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("test-skill");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.toml"),
|
||||
r#"
|
||||
[skill]
|
||||
name = "test-skill"
|
||||
description = "A test skill"
|
||||
version = "1.0.0"
|
||||
tags = ["test"]
|
||||
|
||||
[[tools]]
|
||||
name = "hello"
|
||||
description = "Says hello"
|
||||
kind = "shell"
|
||||
command = "echo hello"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
assert_eq!(skills[0].name, "test-skill");
|
||||
assert_eq!(skills[0].tools.len(), 1);
|
||||
assert_eq!(skills[0].tools[0].name, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_skill_from_md() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("md-skill");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.md"),
|
||||
"# My Skill\nThis skill does cool things.\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
assert_eq!(skills[0].name, "md-skill");
|
||||
assert!(skills[0].description.contains("cool things"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skills_to_prompt_empty() {
|
||||
let prompt = skills_to_prompt(&[]);
|
||||
assert!(prompt.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skills_to_prompt_with_skills() {
|
||||
let skills = vec![Skill {
|
||||
name: "test".to_string(),
|
||||
description: "A test".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
author: None,
|
||||
tags: vec![],
|
||||
tools: vec![],
|
||||
prompts: vec!["Do the thing.".to_string()],
|
||||
}];
|
||||
let prompt = skills_to_prompt(&skills);
|
||||
assert!(prompt.contains("test"));
|
||||
assert!(prompt.contains("Do the thing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn init_skills_creates_readme() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
init_skills_dir(dir.path()).unwrap();
|
||||
assert!(dir.path().join("skills").join("README.md").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn init_skills_idempotent() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
init_skills_dir(dir.path()).unwrap();
|
||||
init_skills_dir(dir.path()).unwrap(); // second call should not fail
|
||||
assert!(dir.path().join("skills").join("README.md").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_nonexistent_dir() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let fake = dir.path().join("nonexistent");
|
||||
let skills = load_skills(&fake);
|
||||
assert!(skills.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ignores_files_in_skills_dir() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
fs::create_dir_all(&skills_dir).unwrap();
|
||||
// A file, not a directory — should be ignored
|
||||
fs::write(skills_dir.join("not-a-skill.txt"), "hello").unwrap();
|
||||
let skills = load_skills(dir.path());
|
||||
assert!(skills.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_ignores_dir_without_manifest() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let empty_skill = skills_dir.join("empty-skill");
|
||||
fs::create_dir_all(&empty_skill).unwrap();
|
||||
// Directory exists but no SKILL.toml or SKILL.md
|
||||
let skills = load_skills(dir.path());
|
||||
assert!(skills.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_multiple_skills() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
|
||||
for name in ["alpha", "beta", "gamma"] {
|
||||
let skill_dir = skills_dir.join(name);
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.md"),
|
||||
format!("# {name}\nSkill {name} description.\n"),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_skill_with_multiple_tools() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("multi-tool");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.toml"),
|
||||
r#"
|
||||
[skill]
|
||||
name = "multi-tool"
|
||||
description = "Has many tools"
|
||||
version = "2.0.0"
|
||||
author = "tester"
|
||||
tags = ["automation", "devops"]
|
||||
|
||||
[[tools]]
|
||||
name = "build"
|
||||
description = "Build the project"
|
||||
kind = "shell"
|
||||
command = "cargo build"
|
||||
|
||||
[[tools]]
|
||||
name = "test"
|
||||
description = "Run tests"
|
||||
kind = "shell"
|
||||
command = "cargo test"
|
||||
|
||||
[[tools]]
|
||||
name = "deploy"
|
||||
description = "Deploy via HTTP"
|
||||
kind = "http"
|
||||
command = "https://api.example.com/deploy"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
let s = &skills[0];
|
||||
assert_eq!(s.name, "multi-tool");
|
||||
assert_eq!(s.version, "2.0.0");
|
||||
assert_eq!(s.author.as_deref(), Some("tester"));
|
||||
assert_eq!(s.tags, vec!["automation", "devops"]);
|
||||
assert_eq!(s.tools.len(), 3);
|
||||
assert_eq!(s.tools[0].name, "build");
|
||||
assert_eq!(s.tools[1].kind, "shell");
|
||||
assert_eq!(s.tools[2].kind, "http");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_skill_minimal() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("minimal");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.toml"),
|
||||
r#"
|
||||
[skill]
|
||||
name = "minimal"
|
||||
description = "Bare minimum"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
assert_eq!(skills[0].version, "0.1.0"); // default version
|
||||
assert!(skills[0].author.is_none());
|
||||
assert!(skills[0].tags.is_empty());
|
||||
assert!(skills[0].tools.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_skill_invalid_syntax_skipped() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("broken");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(skill_dir.join("SKILL.toml"), "this is not valid toml {{{{").unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert!(skills.is_empty()); // broken skill is skipped
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn md_skill_heading_only() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("heading-only");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(skill_dir.join("SKILL.md"), "# Just a Heading\n").unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
assert_eq!(skills[0].description, "No description");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skills_to_prompt_includes_tools() {
|
||||
let skills = vec![Skill {
|
||||
name: "weather".to_string(),
|
||||
description: "Get weather".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
author: None,
|
||||
tags: vec![],
|
||||
tools: vec![SkillTool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Fetch forecast".to_string(),
|
||||
kind: "shell".to_string(),
|
||||
command: "curl wttr.in".to_string(),
|
||||
args: HashMap::new(),
|
||||
}],
|
||||
prompts: vec![],
|
||||
}];
|
||||
let prompt = skills_to_prompt(&skills);
|
||||
assert!(prompt.contains("weather"));
|
||||
assert!(prompt.contains("get_weather"));
|
||||
assert!(prompt.contains("Fetch forecast"));
|
||||
assert!(prompt.contains("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skills_dir_path() {
|
||||
let base = std::path::Path::new("/home/user/.zeroclaw");
|
||||
let dir = skills_dir(base);
|
||||
assert_eq!(dir, PathBuf::from("/home/user/.zeroclaw/skills"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn toml_prefers_over_md() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let skills_dir = dir.path().join("skills");
|
||||
let skill_dir = skills_dir.join("dual");
|
||||
fs::create_dir_all(&skill_dir).unwrap();
|
||||
|
||||
fs::write(
|
||||
skill_dir.join("SKILL.toml"),
|
||||
"[skill]\nname = \"from-toml\"\ndescription = \"TOML wins\"\n",
|
||||
)
|
||||
.unwrap();
|
||||
fs::write(skill_dir.join("SKILL.md"), "# From MD\nMD description\n").unwrap();
|
||||
|
||||
let skills = load_skills(dir.path());
|
||||
assert_eq!(skills.len(), 1);
|
||||
assert_eq!(skills[0].name, "from-toml"); // TOML takes priority
|
||||
}
|
||||
}
|
||||
203
src/tools/file_read.rs
Normal file
203
src/tools/file_read.rs
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Read file contents with path sandboxing
|
||||
pub struct FileReadTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl FileReadTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileReadTool {
|
||||
fn name(&self) -> &str {
|
||||
"file_read"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the contents of a file in the workspace"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the file within the workspace"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let path = args
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
|
||||
|
||||
// Security check: validate path is within workspace
|
||||
if !self.security.is_path_allowed(path) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Path not allowed by security policy: {path}")),
|
||||
});
|
||||
}
|
||||
|
||||
let full_path = self.security.workspace_dir.join(path);
|
||||
|
||||
match tokio::fs::read_to_string(&full_path).await {
|
||||
Ok(contents) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: contents,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to read file: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security(workspace: std::path::PathBuf) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
workspace_dir: workspace,
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_read_name() {
|
||||
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
|
||||
assert_eq!(tool.name(), "file_read");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_read_schema_has_path() {
|
||||
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["path"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("path")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_existing_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_read");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "hello world")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||
let result = tool.execute(json!({"path": "test.txt"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "hello world");
|
||||
assert!(result.error.is_none());
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_nonexistent_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_missing");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||
let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("Failed to read"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_blocks_path_traversal() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_traversal");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "../../../etc/passwd"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_blocks_absolute_path() {
|
||||
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
|
||||
let result = tool.execute(json!({"path": "/etc/passwd"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_missing_path_param() {
|
||||
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_empty_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_empty");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("empty.txt"), "").await.unwrap();
|
||||
|
||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||
let result = tool.execute(json!({"path": "empty.txt"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_read_nested_path() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_nested");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(dir.join("sub/dir"))
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::fs::write(dir.join("sub/dir/deep.txt"), "deep content")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "sub/dir/deep.txt"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output, "deep content");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
}
|
||||
242
src/tools/file_write.rs
Normal file
242
src/tools/file_write.rs
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Write file contents with path sandboxing
|
||||
pub struct FileWriteTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl FileWriteTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileWriteTool {
|
||||
fn name(&self) -> &str {
|
||||
"file_write"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Write contents to a file in the workspace"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the file within the workspace"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let path = args
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
|
||||
|
||||
let content = args
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
|
||||
|
||||
// Security check: validate path is within workspace
|
||||
if !self.security.is_path_allowed(path) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Path not allowed by security policy: {path}")),
|
||||
});
|
||||
}
|
||||
|
||||
let full_path = self.security.workspace_dir.join(path);
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = full_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
match tokio::fs::write(&full_path, content).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Written {} bytes to {path}", content.len()),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to write file: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security(workspace: std::path::PathBuf) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Supervised,
|
||||
workspace_dir: workspace,
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_write_name() {
|
||||
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
|
||||
assert_eq!(tool.name(), "file_write");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_write_schema_has_path_and_content() {
|
||||
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["path"].is_object());
|
||||
assert!(schema["properties"]["content"].is_object());
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.contains(&json!("path")));
|
||||
assert!(required.contains(&json!("content")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_creates_file() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_write");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileWriteTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "out.txt", "content": "written!"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("8 bytes"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("out.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "written!");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_creates_parent_dirs() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_nested");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileWriteTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "a/b/c/deep.txt", "content": "deep"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("a/b/c/deep.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "deep");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_overwrites_existing() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_overwrite");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("exist.txt"), "old")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileWriteTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "exist.txt", "content": "new"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("exist.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "new");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_blocks_path_traversal() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_traversal");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileWriteTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "../../etc/evil", "content": "bad"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_blocks_absolute_path() {
|
||||
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "/etc/evil", "content": "bad"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_missing_path_param() {
|
||||
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
|
||||
let result = tool.execute(json!({"content": "data"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_missing_content_param() {
|
||||
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
|
||||
let result = tool.execute(json!({"path": "file.txt"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_write_empty_content() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_empty");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
|
||||
let tool = FileWriteTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({"path": "empty.txt", "content": ""}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("0 bytes"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
}
|
||||
118
src/tools/memory_forget.rs
Normal file
118
src/tools/memory_forget.rs
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Let the agent forget/delete a memory entry
|
||||
pub struct MemoryForgetTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl MemoryForgetTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryForgetTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_forget"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Remove a memory by key. Use to delete outdated facts or sensitive data. Returns whether the memory was found and removed."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "The key of the memory to forget"
|
||||
}
|
||||
},
|
||||
"required": ["key"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
|
||||
|
||||
match self.memory.forget(key).await {
|
||||
Ok(true) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Forgot memory: {key}"),
|
||||
error: None,
|
||||
}),
|
||||
Ok(false) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("No memory found with key: {key}"),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to forget memory: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
assert_eq!(tool.name(), "memory_forget");
|
||||
assert!(tool.parameters_schema()["properties"]["key"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_existing() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryForgetTool::new(mem.clone());
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Forgot"));
|
||||
|
||||
assert!(mem.get("temp").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_nonexistent() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let result = tool.execute(json!({"key": "nope"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No memory found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_missing_key() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
163
src/tools/memory_recall.rs
Normal file
163
src/tools/memory_recall.rs
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Let the agent search its own memory
|
||||
pub struct MemoryRecallTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl MemoryRecallTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryRecallTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_recall"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Keywords or phrase to search for in memory"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default: 5)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args
|
||||
.get("query")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(5, |v| v as usize);
|
||||
|
||||
match self.memory.recall(query, limit).await {
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No memories found matching that query.".into(),
|
||||
error: None,
|
||||
}),
|
||||
Ok(entries) => {
|
||||
let mut output = format!("Found {} memories:\n", entries.len());
|
||||
for entry in &entries {
|
||||
let score = entry.score.map_or_else(String::new, |s| format!(" [{s:.0}%]"));
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"- [{}] {}: {}{score}",
|
||||
entry.category, entry.key, entry.content
|
||||
);
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Memory recall failed: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn seeded_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_empty() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool
|
||||
.execute(json!({"query": "anything"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No memories found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_finds_match() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool.execute(json!({"query": "Rust"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Rust"));
|
||||
assert!(result.output.contains("Found 1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_respects_limit() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
for i in 0..10 {
|
||||
mem.store(&format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool
|
||||
.execute(json!({"query": "Rust", "limit": 3}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Found 3"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_missing_query() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
let tool = MemoryRecallTool::new(mem);
|
||||
assert_eq!(tool.name(), "memory_recall");
|
||||
assert!(tool.parameters_schema()["properties"]["query"].is_object());
|
||||
}
|
||||
}
|
||||
146
src/tools/memory_store.rs
Normal file
146
src/tools/memory_store.rs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::{Memory, MemoryCategory};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Let the agent store memories — its own brain writes
|
||||
pub struct MemoryStoreTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl MemoryStoreTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MemoryStoreTool {
|
||||
fn name(&self) -> &str {
|
||||
"memory_store"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Store a fact, preference, or note in long-term memory. Use category 'core' for permanent facts, 'daily' for session notes, 'conversation' for chat context."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Unique key for this memory (e.g. 'user_lang', 'project_stack')"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The information to remember"
|
||||
},
|
||||
"category": {
|
||||
"type": "string",
|
||||
"enum": ["core", "daily", "conversation"],
|
||||
"description": "Memory category: core (permanent), daily (session), conversation (chat)"
|
||||
}
|
||||
},
|
||||
"required": ["key", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
|
||||
|
||||
let content = args
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
|
||||
|
||||
let category = match args.get("category").and_then(|v| v.as_str()) {
|
||||
Some("daily") => MemoryCategory::Daily,
|
||||
Some("conversation") => MemoryCategory::Conversation,
|
||||
_ => MemoryCategory::Core,
|
||||
};
|
||||
|
||||
match self.memory.store(key, content, category).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Stored memory: {key}"),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to store memory: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::SqliteMemory;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
assert_eq!(tool.name(), "memory_store");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["key"].is_object());
|
||||
assert!(schema["properties"]["content"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_core() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem.clone());
|
||||
let result = tool
|
||||
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("lang"));
|
||||
|
||||
let entry = mem.get("lang").await.unwrap();
|
||||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "Prefers Rust");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_with_category() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem.clone());
|
||||
let result = tool
|
||||
.execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_missing_key() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
let result = tool.execute(json!({"content": "no key"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_missing_content() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
let result = tool.execute(json!({"key": "no_content"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
189
src/tools/mod.rs
Normal file
189
src/tools/mod.rs
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod shell;
|
||||
pub mod traits;
|
||||
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use traits::Tool;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{ToolResult, ToolSpec};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::memory::Memory;
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Create the default tool registry
|
||||
pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> {
|
||||
vec![
|
||||
Box::new(ShellTool::new(security.clone())),
|
||||
Box::new(FileReadTool::new(security.clone())),
|
||||
Box::new(FileWriteTool::new(security)),
|
||||
]
|
||||
}
|
||||
|
||||
/// Create full tool registry including memory tools
|
||||
pub fn all_tools(
|
||||
security: Arc<SecurityPolicy>,
|
||||
memory: Arc<dyn Memory>,
|
||||
) -> Vec<Box<dyn Tool>> {
|
||||
vec![
|
||||
Box::new(ShellTool::new(security.clone())),
|
||||
Box::new(FileReadTool::new(security.clone())),
|
||||
Box::new(FileWriteTool::new(security)),
|
||||
Box::new(MemoryStoreTool::new(memory.clone())),
|
||||
Box::new(MemoryRecallTool::new(memory.clone())),
|
||||
Box::new(MemoryForgetTool::new(memory)),
|
||||
]
|
||||
}
|
||||
|
||||
pub async fn handle_command(command: super::ToolCommands, config: Config) -> Result<()> {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
workspace_dir: config.workspace_dir.clone(),
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let mem: Arc<dyn Memory> =
|
||||
Arc::from(crate::memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||
let tools_list = all_tools(security, mem);
|
||||
|
||||
match command {
|
||||
super::ToolCommands::List => {
|
||||
println!("Available tools ({}):", tools_list.len());
|
||||
for tool in &tools_list {
|
||||
println!(" - {}: {}", tool.name(), tool.description());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
super::ToolCommands::Test { tool, args } => {
|
||||
let matched = tools_list.iter().find(|t| t.name() == tool);
|
||||
match matched {
|
||||
Some(t) => {
|
||||
let parsed: serde_json::Value = serde_json::from_str(&args)?;
|
||||
let result = t.execute(parsed).await?;
|
||||
println!("Success: {}", result.success);
|
||||
println!("Output: {}", result.output);
|
||||
if let Some(err) = result.error {
|
||||
println!("Error: {err}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => anyhow::bail!("Unknown tool: {tool}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_tools_has_three() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tools = default_tools(security);
|
||||
assert_eq!(tools.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_tools_names() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tools = default_tools(security);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"shell"));
|
||||
assert!(names.contains(&"file_read"));
|
||||
assert!(names.contains(&"file_write"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_tools_all_have_descriptions() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tools = default_tools(security);
|
||||
for tool in &tools {
|
||||
assert!(
|
||||
!tool.description().is_empty(),
|
||||
"Tool {} has empty description",
|
||||
tool.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_tools_all_have_schemas() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tools = default_tools(security);
|
||||
for tool in &tools {
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(
|
||||
schema.is_object(),
|
||||
"Tool {} schema is not an object",
|
||||
tool.name()
|
||||
);
|
||||
assert!(
|
||||
schema["properties"].is_object(),
|
||||
"Tool {} schema has no properties",
|
||||
tool.name()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_spec_generation() {
|
||||
let security = Arc::new(SecurityPolicy::default());
|
||||
let tools = default_tools(security);
|
||||
for tool in &tools {
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, tool.name());
|
||||
assert_eq!(spec.description, tool.description());
|
||||
assert!(spec.parameters.is_object());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_result_serde() {
|
||||
let result = ToolResult {
|
||||
success: true,
|
||||
output: "hello".into(),
|
||||
error: None,
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
let parsed: ToolResult = serde_json::from_str(&json).unwrap();
|
||||
assert!(parsed.success);
|
||||
assert_eq!(parsed.output, "hello");
|
||||
assert!(parsed.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_result_with_error_serde() {
|
||||
let result = ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("boom".into()),
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
let parsed: ToolResult = serde_json::from_str(&json).unwrap();
|
||||
assert!(!parsed.success);
|
||||
assert_eq!(parsed.error.as_deref(), Some("boom"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_spec_serde() {
|
||||
let spec = ToolSpec {
|
||||
name: "test".into(),
|
||||
description: "A test tool".into(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
};
|
||||
let json = serde_json::to_string(&spec).unwrap();
|
||||
let parsed: ToolSpec = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, "test");
|
||||
assert_eq!(parsed.description, "A test tool");
|
||||
}
|
||||
}
|
||||
166
src/tools/shell.rs
Normal file
166
src/tools/shell.rs
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Shell command execution tool with sandboxing
|
||||
pub struct ShellTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl ShellTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ShellTool {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Execute a shell command in the workspace directory"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = args
|
||||
.get("command")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
|
||||
|
||||
// Security check: validate command against allowlist
|
||||
if !self.security.is_command_allowed(command) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Command not allowed by security policy: {command}")),
|
||||
});
|
||||
}
|
||||
|
||||
let output = tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.current_dir(&self.security.workspace_dir)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_tool_name() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
assert_eq!(tool.name(), "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_tool_description() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
assert!(!tool.description().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_tool_schema_has_command() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["command"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("command")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_executes_allowed_command() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let result = tool
|
||||
.execute(json!({"command": "echo hello"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.trim().contains("hello"));
|
||||
assert!(result.error.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_blocks_disallowed_command() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_blocks_readonly() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly));
|
||||
let result = tool.execute(json!({"command": "ls"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_missing_command_param() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("command"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_wrong_type_param() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let result = tool.execute(json!({"command": 123})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn shell_captures_exit_code() {
|
||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
||||
let result = tool
|
||||
.execute(json!({"command": "ls /nonexistent_dir_xyz"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
}
|
||||
43
src/tools/traits.rs
Normal file
43
src/tools/traits.rs
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Result of a tool execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
pub success: bool,
|
||||
pub output: String,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Description of a tool for the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolSpec {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Core tool trait — implement for any capability
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
/// Tool name (used in LLM function calling)
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Human-readable description
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// JSON schema for parameters
|
||||
fn parameters_schema(&self) -> serde_json::Value;
|
||||
|
||||
/// Execute the tool with given arguments
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
|
||||
|
||||
/// Get the full spec for LLM registration
|
||||
fn spec(&self) -> ToolSpec {
|
||||
ToolSpec {
|
||||
name: self.name().to_string(),
|
||||
description: self.description().to_string(),
|
||||
parameters: self.parameters_schema(),
|
||||
}
|
||||
}
|
||||
}
|
||||
369
tests/memory_comparison.rs
Normal file
369
tests/memory_comparison.rs
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
//! Head-to-head comparison: SQLite vs Markdown memory backends
|
||||
//!
|
||||
//! Run with: cargo test --test memory_comparison -- --nocapture
|
||||
|
||||
use std::time::Instant;
|
||||
use tempfile::TempDir;
|
||||
|
||||
// We test both backends through the public memory module
|
||||
use zeroclaw::memory::{
|
||||
markdown::MarkdownMemory, sqlite::SqliteMemory, Memory, MemoryCategory,
|
||||
};
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────
|
||||
|
||||
fn sqlite_backend(dir: &std::path::Path) -> SqliteMemory {
|
||||
SqliteMemory::new(dir).expect("SQLite init failed")
|
||||
}
|
||||
|
||||
fn markdown_backend(dir: &std::path::Path) -> MarkdownMemory {
|
||||
MarkdownMemory::new(dir)
|
||||
}
|
||||
|
||||
// ── Test 1: Store performance ──────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_store_speed() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
let n = 100;
|
||||
|
||||
// SQLite: 100 stores
|
||||
let start = Instant::now();
|
||||
for i in 0..n {
|
||||
sq.store(
|
||||
&format!("key_{i}"),
|
||||
&format!("Memory entry number {i} about Rust programming"),
|
||||
MemoryCategory::Core,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
let sq_dur = start.elapsed();
|
||||
|
||||
// Markdown: 100 stores
|
||||
let start = Instant::now();
|
||||
for i in 0..n {
|
||||
md.store(
|
||||
&format!("key_{i}"),
|
||||
&format!("Memory entry number {i} about Rust programming"),
|
||||
MemoryCategory::Core,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
let md_dur = start.elapsed();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("STORE {n} entries:");
|
||||
println!(" SQLite: {:?}", sq_dur);
|
||||
println!(" Markdown: {:?}", md_dur);
|
||||
|
||||
// Both should succeed
|
||||
assert_eq!(sq.count().await.unwrap(), n);
|
||||
// Markdown count parses lines, may differ slightly from n
|
||||
let md_count = md.count().await.unwrap();
|
||||
assert!(md_count >= n, "Markdown stored {md_count}, expected >= {n}");
|
||||
}
|
||||
|
||||
// ── Test 2: Recall / search quality ────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_recall_quality() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Seed both with identical data
|
||||
let entries = vec![
|
||||
("lang_pref", "User prefers Rust over Python", MemoryCategory::Core),
|
||||
("editor", "Uses VS Code with rust-analyzer", MemoryCategory::Core),
|
||||
("tz", "Timezone is EST, works 9-5", MemoryCategory::Core),
|
||||
("proj1", "Working on ZeroClaw AI assistant", MemoryCategory::Daily),
|
||||
("proj2", "Previous project was a web scraper in Python", MemoryCategory::Daily),
|
||||
("deploy", "Deploys to Hetzner VPS via Docker", MemoryCategory::Core),
|
||||
("model", "Prefers Claude Sonnet for coding tasks", MemoryCategory::Core),
|
||||
("style", "Likes concise responses, no fluff", MemoryCategory::Core),
|
||||
("rust_note", "Rust's ownership model prevents memory bugs", MemoryCategory::Daily),
|
||||
("perf", "Cares about binary size and startup time", MemoryCategory::Core),
|
||||
];
|
||||
|
||||
for (key, content, cat) in &entries {
|
||||
sq.store(key, content, cat.clone()).await.unwrap();
|
||||
md.store(key, content, cat.clone()).await.unwrap();
|
||||
}
|
||||
|
||||
// Test queries and compare results
|
||||
let queries = vec![
|
||||
("Rust", "Should find Rust-related entries"),
|
||||
("Python", "Should find Python references"),
|
||||
("deploy Docker", "Multi-keyword search"),
|
||||
("Claude", "Specific tool reference"),
|
||||
("javascript", "No matches expected"),
|
||||
("binary size startup", "Multi-keyword partial match"),
|
||||
];
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||
|
||||
for (query, desc) in &queries {
|
||||
let sq_results = sq.recall(query, 10).await.unwrap();
|
||||
let md_results = md.recall(query, 10).await.unwrap();
|
||||
|
||||
println!(" Query: \"{query}\" — {desc}");
|
||||
println!(" SQLite: {} results", sq_results.len());
|
||||
for r in &sq_results {
|
||||
println!(
|
||||
" [{:.2}] {}: {}",
|
||||
r.score.unwrap_or(0.0),
|
||||
r.key,
|
||||
&r.content[..r.content.len().min(50)]
|
||||
);
|
||||
}
|
||||
println!(" Markdown: {} results", md_results.len());
|
||||
for r in &md_results {
|
||||
println!(
|
||||
" [{:.2}] {}: {}",
|
||||
r.score.unwrap_or(0.0),
|
||||
r.key,
|
||||
&r.content[..r.content.len().min(50)]
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
// ── Test 3: Recall speed at scale ──────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_recall_speed() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Seed 200 entries
|
||||
let n = 200;
|
||||
for i in 0..n {
|
||||
let content = if i % 3 == 0 {
|
||||
format!("Rust is great for systems programming, entry {i}")
|
||||
} else if i % 3 == 1 {
|
||||
format!("Python is popular for data science, entry {i}")
|
||||
} else {
|
||||
format!("TypeScript powers modern web apps, entry {i}")
|
||||
};
|
||||
sq.store(&format!("e{i}"), &content, MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store(&format!("e{i}"), &content, MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Benchmark recall
|
||||
let start = Instant::now();
|
||||
let sq_results = sq.recall("Rust systems", 10).await.unwrap();
|
||||
let sq_dur = start.elapsed();
|
||||
|
||||
let start = Instant::now();
|
||||
let md_results = md.recall("Rust systems", 10).await.unwrap();
|
||||
let md_dur = start.elapsed();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("RECALL from {n} entries (query: \"Rust systems\", limit 10):");
|
||||
println!(" SQLite: {:?} → {} results", sq_dur, sq_results.len());
|
||||
println!(" Markdown: {:?} → {} results", md_dur, md_results.len());
|
||||
|
||||
// Both should find results
|
||||
assert!(!sq_results.is_empty());
|
||||
assert!(!md_results.is_empty());
|
||||
}
|
||||
|
||||
// ── Test 4: Persistence (SQLite wins by design) ────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_persistence() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
|
||||
// Store in both, then drop and re-open
|
||||
{
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
sq.store("persist_test", "I should survive", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
{
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
md.store("persist_test", "I should survive", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Re-open
|
||||
let sq2 = sqlite_backend(tmp_sq.path());
|
||||
let md2 = markdown_backend(tmp_md.path());
|
||||
|
||||
let sq_entry = sq2.get("persist_test").await.unwrap();
|
||||
let md_entry = md2.get("persist_test").await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("PERSISTENCE (store → drop → re-open → get):");
|
||||
println!(
|
||||
" SQLite: {}",
|
||||
if sq_entry.is_some() {
|
||||
"✅ Survived"
|
||||
} else {
|
||||
"❌ Lost"
|
||||
}
|
||||
);
|
||||
println!(
|
||||
" Markdown: {}",
|
||||
if md_entry.is_some() {
|
||||
"✅ Survived"
|
||||
} else {
|
||||
"❌ Lost"
|
||||
}
|
||||
);
|
||||
|
||||
// SQLite should always persist by key
|
||||
assert!(sq_entry.is_some());
|
||||
assert_eq!(sq_entry.unwrap().content, "I should survive");
|
||||
|
||||
// Markdown persists content to files (get uses content search)
|
||||
assert!(md_entry.is_some());
|
||||
}
|
||||
|
||||
// ── Test 5: Upsert / update behavior ──────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_upsert() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Store twice with same key, different content
|
||||
sq.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
md.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let sq_count = sq.count().await.unwrap();
|
||||
let md_count = md.count().await.unwrap();
|
||||
|
||||
let sq_entry = sq.get("pref").await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("UPSERT (store same key twice):");
|
||||
println!(" SQLite: count={sq_count}, latest=\"{}\"",
|
||||
sq_entry.as_ref().map_or("none", |e| &e.content));
|
||||
println!(" Markdown: count={md_count} (append-only, both entries kept)");
|
||||
println!(" Can still find latest: {}", !md_results.is_empty());
|
||||
|
||||
// SQLite: upsert replaces, count stays at 1
|
||||
assert_eq!(sq_count, 1);
|
||||
assert_eq!(sq_entry.unwrap().content, "loves Rust");
|
||||
|
||||
// Markdown: append-only, count increases
|
||||
assert!(md_count >= 2, "Markdown should keep both entries");
|
||||
}
|
||||
|
||||
// ── Test 6: Forget / delete capability ─────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_forget() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
sq.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let sq_forgot = sq.forget("secret").await.unwrap();
|
||||
let md_forgot = md.forget("secret").await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("FORGET (delete sensitive data):");
|
||||
println!(
|
||||
" SQLite: {} (count={})",
|
||||
if sq_forgot { "✅ Deleted" } else { "❌ Kept" },
|
||||
sq.count().await.unwrap()
|
||||
);
|
||||
println!(
|
||||
" Markdown: {} (append-only by design)",
|
||||
if md_forgot { "✅ Deleted" } else { "⚠️ Cannot delete (audit trail)" },
|
||||
);
|
||||
|
||||
// SQLite can delete
|
||||
assert!(sq_forgot);
|
||||
assert_eq!(sq.count().await.unwrap(), 0);
|
||||
|
||||
// Markdown cannot delete (by design)
|
||||
assert!(!md_forgot);
|
||||
}
|
||||
|
||||
// ── Test 7: Category filtering ─────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn compare_category_filter() {
|
||||
let tmp_sq = TempDir::new().unwrap();
|
||||
let tmp_md = TempDir::new().unwrap();
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Mix of categories
|
||||
sq.store("a", "core fact 1", MemoryCategory::Core).await.unwrap();
|
||||
sq.store("b", "core fact 2", MemoryCategory::Core).await.unwrap();
|
||||
sq.store("c", "daily note", MemoryCategory::Daily).await.unwrap();
|
||||
sq.store("d", "convo msg", MemoryCategory::Conversation).await.unwrap();
|
||||
|
||||
md.store("a", "core fact 1", MemoryCategory::Core).await.unwrap();
|
||||
md.store("b", "core fact 2", MemoryCategory::Core).await.unwrap();
|
||||
md.store("c", "daily note", MemoryCategory::Daily).await.unwrap();
|
||||
|
||||
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap();
|
||||
let sq_all = sq.list(None).await.unwrap();
|
||||
|
||||
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let md_all = md.list(None).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("CATEGORY FILTERING:");
|
||||
println!(" SQLite: core={}, daily={}, conv={}, all={}",
|
||||
sq_core.len(), sq_daily.len(), sq_conv.len(), sq_all.len());
|
||||
println!(" Markdown: core={}, daily={}, all={}",
|
||||
md_core.len(), md_daily.len(), md_all.len());
|
||||
|
||||
// SQLite: precise category filtering via SQL WHERE
|
||||
assert_eq!(sq_core.len(), 2);
|
||||
assert_eq!(sq_daily.len(), 1);
|
||||
assert_eq!(sq_conv.len(), 1);
|
||||
assert_eq!(sq_all.len(), 4);
|
||||
|
||||
// Markdown: categories determined by file location
|
||||
assert!(!md_core.is_empty());
|
||||
assert!(!md_all.is_empty());
|
||||
}
|
||||
BIN
zeroclaw.png
Normal file
BIN
zeroclaw.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 MiB |
Loading…
Add table
Add a link
Reference in a new issue