Merge pull request #535 from zeroclaw-labs/pr-484-clean
fix: Discord channel replies and parking_lot::Mutex migration
This commit is contained in:
commit
acfdc34be2
18 changed files with 130 additions and 203 deletions
|
|
@ -7,7 +7,7 @@ name = "zeroclaw"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT"
|
||||
license = "Apache-2.0"
|
||||
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||
readme = "README.md"
|
||||
|
|
|
|||
29
Dockerfile
29
Dockerfile
|
|
@ -1,12 +1,14 @@
|
|||
# syntax=docker/dockerfile:1
|
||||
# syntax=docker/dockerfile:1.7
|
||||
|
||||
# ── Stage 1: Build ────────────────────────────────────────────
|
||||
FROM rust:1.93-slim-trixie@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder
|
||||
FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
|
@ -14,19 +16,20 @@ RUN apt-get update && apt-get install -y \
|
|||
COPY Cargo.toml Cargo.lock ./
|
||||
# Create dummy main.rs to build dependencies
|
||||
RUN mkdir src && echo "fn main() {}" > src/main.rs
|
||||
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
||||
--mount=type=cache,target=/usr/local/cargo/git \
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
cargo build --release --locked
|
||||
RUN rm -rf src
|
||||
|
||||
# 2. Copy source code
|
||||
COPY . .
|
||||
# Touch main.rs to force rebuild
|
||||
RUN touch src/main.rs
|
||||
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
||||
--mount=type=cache,target=/usr/local/cargo/git \
|
||||
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||
cargo build --release --locked && \
|
||||
strip target/release/zeroclaw
|
||||
cp target/release/zeroclaw /app/zeroclaw && \
|
||||
strip /app/zeroclaw
|
||||
|
||||
# ── Stage 2: Permissions & Config Prep ───────────────────────
|
||||
FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions
|
||||
|
|
@ -35,7 +38,7 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace
|
|||
|
||||
# Create minimal config for PRODUCTION (allows binding to public interfaces)
|
||||
# NOTE: Provider configuration must be done via environment variables at runtime
|
||||
RUN cat > /zeroclaw-data/.zeroclaw/config.toml << 'EOF'
|
||||
RUN cat > /zeroclaw-data/.zeroclaw/config.toml <<EOF
|
||||
workspace_dir = "/zeroclaw-data/workspace"
|
||||
config_path = "/zeroclaw-data/.zeroclaw/config.toml"
|
||||
api_key = ""
|
||||
|
|
@ -65,7 +68,7 @@ RUN apt-get update && apt-get install -y \
|
|||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
||||
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
|
||||
|
||||
# Overwrite minimal config with DEV template (Ollama defaults)
|
||||
COPY dev/config.template.toml /zeroclaw-data/.zeroclaw/config.toml
|
||||
|
|
@ -92,7 +95,7 @@ CMD ["gateway", "--port", "3000", "--host", "[::]"]
|
|||
# ── Stage 4: Production Runtime (Distroless) ─────────────────
|
||||
FROM gcr.io/distroless/cc-debian13:nonroot@sha256:84fcd3c223b144b0cb6edc5ecc75641819842a9679a3a58fd6294bec47532bf7 AS release
|
||||
|
||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
||||
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
|
||||
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
||||
|
||||
# Environment setup
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
<p align="center">
|
||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
||||
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=flat&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
|
||||
</p>
|
||||
|
||||
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
||||
|
|
@ -616,12 +615,6 @@ For high-throughput collaboration and consistent reviews:
|
|||
- CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md)
|
||||
- Security disclosure policy: [SECURITY.md](SECURITY.md)
|
||||
|
||||
## Support
|
||||
|
||||
ZeroClaw is an open-source project maintained with passion. If you find it useful and would like to support its continued development, hardware for testing, and coffee for the maintainer, you can support me here:
|
||||
|
||||
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=for-the-badge&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
|
||||
|
||||
### 🙏 Special Thanks
|
||||
|
||||
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:
|
||||
|
|
|
|||
|
|
@ -163,5 +163,7 @@ Note: local `deny` focuses on license/source policy; advisory scanning is handle
|
|||
### Build cache notes
|
||||
|
||||
- Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data.
|
||||
- The root `Dockerfile` also caches Rust `target/` (`id=zeroclaw-target`) to speed repeat local image builds.
|
||||
- Local CI reuses named Docker volumes for Cargo registry/git and target outputs.
|
||||
- `./dev/ci.sh docker-smoke` and `./dev/ci.sh all` now use `docker buildx` local cache at `.cache/buildx-smoke` when available.
|
||||
- The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run.
|
||||
|
|
|
|||
24
dev/ci.sh
24
dev/ci.sh
|
|
@ -11,12 +11,32 @@ else
|
|||
fi
|
||||
|
||||
compose_cmd=(docker compose -f "$COMPOSE_FILE")
|
||||
SMOKE_CACHE_DIR="${SMOKE_CACHE_DIR:-.cache/buildx-smoke}"
|
||||
|
||||
run_in_ci() {
|
||||
local cmd="$1"
|
||||
"${compose_cmd[@]}" run --rm local-ci bash -c "$cmd"
|
||||
}
|
||||
|
||||
build_smoke_image() {
|
||||
if docker buildx version >/dev/null 2>&1; then
|
||||
mkdir -p "$SMOKE_CACHE_DIR"
|
||||
local build_args=(
|
||||
--load
|
||||
--target dev
|
||||
--cache-to "type=local,dest=$SMOKE_CACHE_DIR,mode=max"
|
||||
-t zeroclaw-local-smoke:latest
|
||||
.
|
||||
)
|
||||
if [ -f "$SMOKE_CACHE_DIR/index.json" ]; then
|
||||
build_args=(--cache-from "type=local,src=$SMOKE_CACHE_DIR" "${build_args[@]}")
|
||||
fi
|
||||
docker buildx build "${build_args[@]}"
|
||||
else
|
||||
DOCKER_BUILDKIT=1 docker build --target dev -t zeroclaw-local-smoke:latest .
|
||||
fi
|
||||
}
|
||||
|
||||
print_help() {
|
||||
cat <<'EOF'
|
||||
ZeroClaw Local CI in Docker
|
||||
|
|
@ -88,7 +108,7 @@ case "$1" in
|
|||
;;
|
||||
|
||||
docker-smoke)
|
||||
docker build --target dev -t zeroclaw-local-smoke:latest .
|
||||
build_smoke_image
|
||||
docker run --rm zeroclaw-local-smoke:latest --version
|
||||
;;
|
||||
|
||||
|
|
@ -98,7 +118,7 @@ case "$1" in
|
|||
run_in_ci "cargo build --release --locked --verbose"
|
||||
run_in_ci "cargo deny check licenses sources"
|
||||
run_in_ci "cargo audit"
|
||||
docker build --target dev -t zeroclaw-local-smoke:latest .
|
||||
build_smoke_image
|
||||
docker run --rm zeroclaw-local-smoke:latest --version
|
||||
;;
|
||||
|
||||
|
|
|
|||
|
|
@ -568,7 +568,7 @@ pub async fn run(
|
|||
mod tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Mutex;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
struct MockProvider {
|
||||
responses: Mutex<Vec<crate::providers::ChatResponse>>,
|
||||
|
|
@ -592,7 +592,7 @@ mod tests {
|
|||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> Result<crate::providers::ChatResponse> {
|
||||
let mut guard = self.responses.lock().unwrap();
|
||||
let mut guard = self.responses.lock();
|
||||
if guard.is_empty() {
|
||||
return Ok(crate::providers::ChatResponse {
|
||||
text: Some("done".into()),
|
||||
|
|
|
|||
|
|
@ -363,11 +363,7 @@ impl Channel for DiscordChannel {
|
|||
};
|
||||
|
||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let channel_id = d
|
||||
.get("channel_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: if message_id.is_empty() {
|
||||
|
|
@ -379,10 +375,10 @@ impl Channel for DiscordChannel {
|
|||
reply_target: if channel_id.is_empty() {
|
||||
author_id.to_string()
|
||||
} else {
|
||||
channel_id
|
||||
channel_id.clone()
|
||||
},
|
||||
content: content.to_string(),
|
||||
channel: "discord".to_string(),
|
||||
content: clean_content,
|
||||
channel: channel_id,
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
|
|
|
|||
|
|
@ -14,11 +14,11 @@ use lettre::message::SinglePart;
|
|||
use lettre::transport::smtp::authentication::Credentials;
|
||||
use lettre::{Message, SmtpTransport, Transport};
|
||||
use mail_parser::{MessageParser, MimeHeaders};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::io::Write as IoWrite;
|
||||
use std::net::TcpStream;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{interval, sleep};
|
||||
|
|
@ -413,10 +413,7 @@ impl Channel for EmailChannel {
|
|||
Ok(Ok(messages)) => {
|
||||
for (id, sender, content, ts) in messages {
|
||||
{
|
||||
let mut seen = self
|
||||
.seen_messages
|
||||
.lock()
|
||||
.expect("seen_messages mutex should not be poisoned");
|
||||
let mut seen = self.seen_messages.lock();
|
||||
if seen.contains(&id) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -488,20 +485,14 @@ mod tests {
|
|||
#[test]
|
||||
fn seen_messages_starts_empty() {
|
||||
let channel = EmailChannel::new(EmailConfig::default());
|
||||
let seen = channel
|
||||
.seen_messages
|
||||
.lock()
|
||||
.expect("seen_messages mutex should not be poisoned");
|
||||
let seen = channel.seen_messages.lock();
|
||||
assert!(seen.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seen_messages_tracks_unique_ids() {
|
||||
let channel = EmailChannel::new(EmailConfig::default());
|
||||
let mut seen = channel
|
||||
.seen_messages
|
||||
.lock()
|
||||
.expect("seen_messages mutex should not be poisoned");
|
||||
let mut seen = channel.seen_messages.lock();
|
||||
|
||||
assert!(seen.insert("first-id".to_string()));
|
||||
assert!(!seen.insert("first-id".to_string()));
|
||||
|
|
@ -576,10 +567,7 @@ mod tests {
|
|||
let channel = EmailChannel::new(config.clone());
|
||||
assert_eq!(channel.config.imap_host, config.imap_host);
|
||||
|
||||
let seen_guard = channel
|
||||
.seen_messages
|
||||
.lock()
|
||||
.expect("seen_messages mutex should not be poisoned");
|
||||
let seen_guard = channel.seen_messages.lock();
|
||||
assert_eq!(seen_guard.len(), 0);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ use axum::{
|
|||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
|
|
@ -82,10 +83,7 @@ impl SlidingWindowRateLimiter {
|
|||
let now = Instant::now();
|
||||
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
||||
|
||||
let mut guard = self
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let mut guard = self.requests.lock();
|
||||
let (requests, last_sweep) = &mut *guard;
|
||||
|
||||
// Periodic sweep: remove IPs with no recent requests
|
||||
|
|
@ -150,10 +148,7 @@ impl IdempotencyStore {
|
|||
/// Returns true if this key is new and is now recorded.
|
||||
fn record_if_new(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
let mut keys = self
|
||||
.keys
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let mut keys = self.keys.lock();
|
||||
|
||||
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
||||
|
||||
|
|
@ -739,8 +734,8 @@ mod tests {
|
|||
use axum::http::HeaderValue;
|
||||
use axum::response::IntoResponse;
|
||||
use http_body_util::BodyExt;
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Mutex;
|
||||
|
||||
#[test]
|
||||
fn security_body_limit_is_64kb() {
|
||||
|
|
@ -797,19 +792,13 @@ mod tests {
|
|||
assert!(limiter.allow("ip-3"));
|
||||
|
||||
{
|
||||
let guard = limiter
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let guard = limiter.requests.lock();
|
||||
assert_eq!(guard.0.len(), 3);
|
||||
}
|
||||
|
||||
// Force a sweep by backdating last_sweep
|
||||
{
|
||||
let mut guard = limiter
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let mut guard = limiter.requests.lock();
|
||||
guard.1 = Instant::now()
|
||||
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||
.unwrap();
|
||||
|
|
@ -822,10 +811,7 @@ mod tests {
|
|||
assert!(limiter.allow("ip-1"));
|
||||
|
||||
{
|
||||
let guard = limiter
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let guard = limiter.requests.lock();
|
||||
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
|
||||
assert!(guard.0.contains_key("ip-1"));
|
||||
}
|
||||
|
|
@ -962,10 +948,7 @@ mod tests {
|
|||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.keys
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(key.to_string());
|
||||
self.keys.lock().push(key.to_string());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -995,11 +978,7 @@ mod tests {
|
|||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let size = self
|
||||
.keys
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.len();
|
||||
let size = self.keys.lock().len();
|
||||
Ok(size)
|
||||
}
|
||||
|
||||
|
|
@ -1094,11 +1073,7 @@ mod tests {
|
|||
.into_response();
|
||||
assert_eq!(second.status(), StatusCode::OK);
|
||||
|
||||
let keys = tracking_impl
|
||||
.keys
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone();
|
||||
let keys = tracking_impl.keys.lock().clone();
|
||||
assert_eq!(keys.len(), 2);
|
||||
assert_ne!(keys[0], keys[1]);
|
||||
assert!(keys[0].starts_with("webhook_msg_"));
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
|
@ -116,9 +116,7 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
fn in_failure_cooldown(&self) -> bool {
|
||||
let Ok(guard) = self.last_failure_at.lock() else {
|
||||
return false;
|
||||
};
|
||||
let guard = self.last_failure_at.lock();
|
||||
|
||||
guard
|
||||
.as_ref()
|
||||
|
|
@ -126,15 +124,11 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
fn mark_failure_now(&self) {
|
||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||
*guard = Some(Instant::now());
|
||||
}
|
||||
*self.last_failure_at.lock() = Some(Instant::now());
|
||||
}
|
||||
|
||||
fn clear_failure(&self) {
|
||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||
*guard = None;
|
||||
}
|
||||
*self.last_failure_at.lock() = None;
|
||||
}
|
||||
|
||||
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
||||
|
|
@ -565,11 +559,12 @@ exit 1
|
|||
"local_note",
|
||||
"Local sqlite auth fallback note",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5).await.unwrap();
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
|
||||
use anyhow::Result;
|
||||
use chrono::{Duration, Local};
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Response cache backed by a dedicated SQLite database.
|
||||
///
|
||||
|
|
@ -77,10 +77,7 @@ impl ResponseCache {
|
|||
|
||||
/// Look up a cached response. Returns `None` on miss or expired entry.
|
||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now();
|
||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||
|
|
@ -108,10 +105,7 @@ impl ResponseCache {
|
|||
|
||||
/// Store a response in the cache.
|
||||
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let now = Local::now().to_rfc3339();
|
||||
|
||||
|
|
@ -146,10 +140,7 @@ impl ResponseCache {
|
|||
|
||||
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
||||
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let count: i64 =
|
||||
conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
|
||||
|
|
@ -172,10 +163,7 @@ impl ResponseCache {
|
|||
|
||||
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
|
||||
pub fn clear(&self) -> Result<usize> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let affected = conn.execute("DELETE FROM response_cache", [])?;
|
||||
Ok(affected)
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
|||
use super::vector;
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use parking_lot::Mutex;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// SQLite-backed persistent memory — the brain
|
||||
|
|
@ -185,10 +186,7 @@ impl SqliteMemory {
|
|||
|
||||
// Check cache
|
||||
{
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let mut stmt =
|
||||
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
||||
|
|
@ -210,10 +208,7 @@ impl SqliteMemory {
|
|||
|
||||
// Store in cache + LRU eviction
|
||||
{
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
||||
|
|
@ -315,10 +310,7 @@ impl SqliteMemory {
|
|||
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
||||
// Step 1: Rebuild FTS5
|
||||
{
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
||||
}
|
||||
|
|
@ -329,10 +321,7 @@ impl SqliteMemory {
|
|||
}
|
||||
|
||||
let entries: Vec<(String, String)> = {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let mut stmt =
|
||||
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
||||
|
|
@ -346,10 +335,7 @@ impl SqliteMemory {
|
|||
for (id, content) in &entries {
|
||||
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
||||
let bytes = vector::vec_to_bytes(&emb);
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
conn.execute(
|
||||
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
||||
params![bytes, id],
|
||||
|
|
@ -381,10 +367,7 @@ impl Memory for SqliteMemory {
|
|||
.await?
|
||||
.map(|emb| vector::vec_to_bytes(&emb));
|
||||
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
let now = Local::now().to_rfc3339();
|
||||
let cat = Self::category_to_str(&category);
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
|
@ -417,10 +400,7 @@ impl Memory for SqliteMemory {
|
|||
// Compute query embedding (async, before lock)
|
||||
let query_embedding = self.get_or_compute_embedding(query).await?;
|
||||
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
// FTS5 BM25 keyword search
|
||||
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
||||
|
|
@ -539,10 +519,7 @@ impl Memory for SqliteMemory {
|
|||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
|
|
@ -571,10 +548,7 @@ impl Memory for SqliteMemory {
|
|||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
|
|
@ -627,29 +601,20 @@ impl Memory for SqliteMemory {
|
|||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let conn = self.conn.lock();
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(count as usize)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.conn
|
||||
.lock()
|
||||
.map(|c| c.execute_batch("SELECT 1").is_ok())
|
||||
.unwrap_or(false)
|
||||
self.conn.lock().execute_batch("SELECT 1").is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -968,7 +933,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn schema_has_fts5_table() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
// FTS5 table should exist
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
|
|
@ -983,7 +948,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn schema_has_embedding_cache() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
|
||||
|
|
@ -997,7 +962,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn schema_memories_has_embedding_column() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
// Check that embedding column exists by querying it
|
||||
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1017,7 +982,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
|
||||
|
|
@ -1041,7 +1006,7 @@ mod tests {
|
|||
.unwrap();
|
||||
mem.forget("del_key").await.unwrap();
|
||||
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
|
||||
|
|
@ -1067,7 +1032,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let conn = mem.conn.lock().unwrap();
|
||||
let conn = mem.conn.lock();
|
||||
// Old content should not be findable
|
||||
let old: i64 = conn
|
||||
.query_row(
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ pub trait Observer: Send + Sync + 'static {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Mutex;
|
||||
use parking_lot::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Default)]
|
||||
|
|
@ -97,12 +97,12 @@ mod tests {
|
|||
|
||||
impl Observer for DummyObserver {
|
||||
fn record_event(&self, _event: &ObserverEvent) {
|
||||
let mut guard = self.events.lock().unwrap();
|
||||
let mut guard = self.events.lock();
|
||||
*guard += 1;
|
||||
}
|
||||
|
||||
fn record_metric(&self, _metric: &ObserverMetric) {
|
||||
let mut guard = self.metrics.lock().unwrap();
|
||||
let mut guard = self.metrics.lock();
|
||||
*guard += 1;
|
||||
}
|
||||
|
||||
|
|
@ -122,8 +122,8 @@ mod tests {
|
|||
});
|
||||
observer.record_metric(&ObserverMetric::TokensUsed(42));
|
||||
|
||||
assert_eq!(*observer.events.lock().unwrap(), 2);
|
||||
assert_eq!(*observer.metrics.lock().unwrap(), 1);
|
||||
assert_eq!(*observer.events.lock(), 2);
|
||||
assert_eq!(*observer.metrics.lock(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -683,7 +683,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
let credential = match self.credential.as_ref() {
|
||||
Some(key) => key.clone(),
|
||||
Some(value) => value.clone(),
|
||||
None => {
|
||||
let provider_name = self.name.clone();
|
||||
return stream::once(async move {
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ mod tests {
|
|||
/// Mock that records which model was used for each call.
|
||||
struct ModelAwareMock {
|
||||
calls: Arc<AtomicUsize>,
|
||||
models_seen: std::sync::Mutex<Vec<String>>,
|
||||
models_seen: parking_lot::Mutex<Vec<String>>,
|
||||
fail_models: Vec<&'static str>,
|
||||
response: &'static str,
|
||||
}
|
||||
|
|
@ -490,7 +490,7 @@ mod tests {
|
|||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.models_seen.lock().unwrap().push(model.to_string());
|
||||
self.models_seen.lock().push(model.to_string());
|
||||
if self.fail_models.contains(&model) {
|
||||
anyhow::bail!("500 model {} unavailable", model);
|
||||
}
|
||||
|
|
@ -743,7 +743,7 @@ mod tests {
|
|||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = Arc::new(ModelAwareMock {
|
||||
calls: Arc::clone(&calls),
|
||||
models_seen: std::sync::Mutex::new(Vec::new()),
|
||||
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||
fail_models: vec!["claude-opus"],
|
||||
response: "ok from sonnet",
|
||||
});
|
||||
|
|
@ -767,7 +767,7 @@ mod tests {
|
|||
.unwrap();
|
||||
assert_eq!(result, "ok from sonnet");
|
||||
|
||||
let seen = mock.models_seen.lock().unwrap();
|
||||
let seen = mock.models_seen.lock();
|
||||
assert_eq!(seen.len(), 2);
|
||||
assert_eq!(seen[0], "claude-opus");
|
||||
assert_eq!(seen[1], "claude-sonnet");
|
||||
|
|
@ -778,7 +778,7 @@ mod tests {
|
|||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = Arc::new(ModelAwareMock {
|
||||
calls: Arc::clone(&calls),
|
||||
models_seen: std::sync::Mutex::new(Vec::new()),
|
||||
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||
fail_models: vec!["model-a", "model-b", "model-c"],
|
||||
response: "never",
|
||||
});
|
||||
|
|
@ -802,7 +802,7 @@ mod tests {
|
|||
.expect_err("all models should fail");
|
||||
assert!(err.to_string().contains("All providers/models failed"));
|
||||
|
||||
let seen = mock.models_seen.lock().unwrap();
|
||||
let seen = mock.models_seen.lock();
|
||||
assert_eq!(seen.len(), 3);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ mod tests {
|
|||
struct MockProvider {
|
||||
calls: Arc<AtomicUsize>,
|
||||
response: &'static str,
|
||||
last_model: std::sync::Mutex<String>,
|
||||
last_model: parking_lot::Mutex<String>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
|
|
@ -172,7 +172,7 @@ mod tests {
|
|||
Self {
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
response,
|
||||
last_model: std::sync::Mutex::new(String::new()),
|
||||
last_model: parking_lot::Mutex::new(String::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -181,7 +181,7 @@ mod tests {
|
|||
}
|
||||
|
||||
fn last_model(&self) -> String {
|
||||
self.last_model.lock().unwrap().clone()
|
||||
self.last_model.lock().clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ mod tests {
|
|||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
*self.last_model.lock().unwrap() = model.to_string();
|
||||
*self.last_model.lock() = model.to_string();
|
||||
Ok(self.response.to_string())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,6 +76,13 @@ pub struct ChatRequest<'a> {
|
|||
pub tools: Option<&'a [ToolSpec]>,
|
||||
}
|
||||
|
||||
/// Declares optional provider features.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct ProviderCapabilities {
|
||||
/// Provider can perform native tool calling without prompt-level emulation.
|
||||
pub native_tool_calling: bool,
|
||||
}
|
||||
|
||||
/// A tool result to feed back to the LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResultMessage {
|
||||
|
|
@ -191,21 +198,6 @@ pub enum StreamError {
|
|||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Provider capabilities declaration.
|
||||
///
|
||||
/// Describes what features a provider supports, enabling intelligent
|
||||
/// adaptation of tool calling modes and request formatting.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProviderCapabilities {
|
||||
/// Whether the provider supports native tool calling via API primitives.
|
||||
///
|
||||
/// When `true`, the provider can convert tool definitions to API-native
|
||||
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
|
||||
///
|
||||
/// When `false`, tools must be injected via system prompt as text.
|
||||
pub native_tool_calling: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Query provider capabilities.
|
||||
|
|
@ -329,11 +321,21 @@ pub trait Provider: Send + Sync {
|
|||
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||
fn stream_chat_with_history(
|
||||
&self,
|
||||
_messages: &[ChatMessage],
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
_options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
let _system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.clone());
|
||||
let _last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// For default implementation, we need to convert to owned strings
|
||||
// This is a limitation of the default implementation
|
||||
let provider_name = "unknown".to_string();
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@
|
|||
use crate::config::AuditConfig;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use parking_lot::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Audit event types
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue