Merge pull request #535 from zeroclaw-labs/pr-484-clean

fix: Discord channel replies and parking_lot::Mutex migration
This commit is contained in:
Will Sarg 2026-02-17 09:29:28 -05:00 committed by GitHub
commit acfdc34be2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 130 additions and 203 deletions

View file

@ -7,7 +7,7 @@ name = "zeroclaw"
version = "0.1.0"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT"
license = "Apache-2.0"
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
repository = "https://github.com/zeroclaw-labs/zeroclaw"
readme = "README.md"

View file

@ -1,12 +1,14 @@
# syntax=docker/dockerfile:1
# syntax=docker/dockerfile:1.7
# ── Stage 1: Build ────────────────────────────────────────────
FROM rust:1.93-slim-trixie@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder
FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y \
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update && apt-get install -y \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
@ -14,19 +16,20 @@ RUN apt-get update && apt-get install -y \
COPY Cargo.toml Cargo.lock ./
# Create dummy main.rs to build dependencies
RUN mkdir src && echo "fn main() {}" > src/main.rs
RUN --mount=type=cache,target=/usr/local/cargo/registry \
--mount=type=cache,target=/usr/local/cargo/git \
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
cargo build --release --locked
RUN rm -rf src
# 2. Copy source code
COPY . .
# Touch main.rs to force rebuild
RUN touch src/main.rs
RUN --mount=type=cache,target=/usr/local/cargo/registry \
--mount=type=cache,target=/usr/local/cargo/git \
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
cargo build --release --locked && \
strip target/release/zeroclaw
cp target/release/zeroclaw /app/zeroclaw && \
strip /app/zeroclaw
# ── Stage 2: Permissions & Config Prep ───────────────────────
FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions
@ -35,7 +38,7 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace
# Create minimal config for PRODUCTION (allows binding to public interfaces)
# NOTE: Provider configuration must be done via environment variables at runtime
RUN cat > /zeroclaw-data/.zeroclaw/config.toml << 'EOF'
RUN cat > /zeroclaw-data/.zeroclaw/config.toml <<EOF
workspace_dir = "/zeroclaw-data/workspace"
config_path = "/zeroclaw-data/.zeroclaw/config.toml"
api_key = ""
@ -65,7 +68,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
COPY --from=permissions /zeroclaw-data /zeroclaw-data
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
# Overwrite minimal config with DEV template (Ollama defaults)
COPY dev/config.template.toml /zeroclaw-data/.zeroclaw/config.toml
@ -92,7 +95,7 @@ CMD ["gateway", "--port", "3000", "--host", "[::]"]
# ── Stage 4: Production Runtime (Distroless) ─────────────────
FROM gcr.io/distroless/cc-debian13:nonroot@sha256:84fcd3c223b144b0cb6edc5ecc75641819842a9679a3a58fd6294bec47532bf7 AS release
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
COPY --from=permissions /zeroclaw-data /zeroclaw-data
# Environment setup

View file

@ -12,7 +12,6 @@
<p align="center">
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=flat&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
</p>
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
@ -616,12 +615,6 @@ For high-throughput collaboration and consistent reviews:
- CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md)
- Security disclosure policy: [SECURITY.md](SECURITY.md)
## Support
ZeroClaw is an open-source project maintained with passion. If you find it useful and would like to support its continued development, hardware for testing, and coffee for the maintainer, you can support the project here:
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=for-the-badge&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
### 🙏 Special Thanks
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:

View file

@ -163,5 +163,7 @@ Note: local `deny` focuses on license/source policy; advisory scanning is handle
### Build cache notes
- Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data.
- The root `Dockerfile` also caches Rust `target/` (`id=zeroclaw-target`) to speed repeat local image builds.
- Local CI reuses named Docker volumes for Cargo registry/git and target outputs.
- `./dev/ci.sh docker-smoke` and `./dev/ci.sh all` now use `docker buildx` local cache at `.cache/buildx-smoke` when available.
- The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run.

View file

@ -11,12 +11,32 @@ else
fi
compose_cmd=(docker compose -f "$COMPOSE_FILE")
SMOKE_CACHE_DIR="${SMOKE_CACHE_DIR:-.cache/buildx-smoke}"
run_in_ci() {
local cmd="$1"
"${compose_cmd[@]}" run --rm local-ci bash -c "$cmd"
}
build_smoke_image() {
if docker buildx version >/dev/null 2>&1; then
mkdir -p "$SMOKE_CACHE_DIR"
local build_args=(
--load
--target dev
--cache-to "type=local,dest=$SMOKE_CACHE_DIR,mode=max"
-t zeroclaw-local-smoke:latest
.
)
if [ -f "$SMOKE_CACHE_DIR/index.json" ]; then
build_args=(--cache-from "type=local,src=$SMOKE_CACHE_DIR" "${build_args[@]}")
fi
docker buildx build "${build_args[@]}"
else
DOCKER_BUILDKIT=1 docker build --target dev -t zeroclaw-local-smoke:latest .
fi
}
print_help() {
cat <<'EOF'
ZeroClaw Local CI in Docker
@ -88,7 +108,7 @@ case "$1" in
;;
docker-smoke)
docker build --target dev -t zeroclaw-local-smoke:latest .
build_smoke_image
docker run --rm zeroclaw-local-smoke:latest --version
;;
@ -98,7 +118,7 @@ case "$1" in
run_in_ci "cargo build --release --locked --verbose"
run_in_ci "cargo deny check licenses sources"
run_in_ci "cargo audit"
docker build --target dev -t zeroclaw-local-smoke:latest .
build_smoke_image
docker run --rm zeroclaw-local-smoke:latest --version
;;

View file

@ -568,7 +568,7 @@ pub async fn run(
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::Mutex;
use parking_lot::Mutex;
struct MockProvider {
responses: Mutex<Vec<crate::providers::ChatResponse>>,
@ -592,7 +592,7 @@ mod tests {
_model: &str,
_temperature: f64,
) -> Result<crate::providers::ChatResponse> {
let mut guard = self.responses.lock().unwrap();
let mut guard = self.responses.lock();
if guard.is_empty() {
return Ok(crate::providers::ChatResponse {
text: Some("done".into()),

View file

@ -363,11 +363,7 @@ impl Channel for DiscordChannel {
};
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
let channel_id = d
.get("channel_id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
let channel_msg = ChannelMessage {
id: if message_id.is_empty() {
@ -379,10 +375,10 @@ impl Channel for DiscordChannel {
reply_target: if channel_id.is_empty() {
author_id.to_string()
} else {
channel_id
channel_id.clone()
},
content: content.to_string(),
channel: "discord".to_string(),
content: clean_content,
channel: channel_id,
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()

View file

@ -14,11 +14,11 @@ use lettre::message::SinglePart;
use lettre::transport::smtp::authentication::Credentials;
use lettre::{Message, SmtpTransport, Transport};
use mail_parser::{MessageParser, MimeHeaders};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::io::Write as IoWrite;
use std::net::TcpStream;
use std::sync::Mutex;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use tokio::time::{interval, sleep};
@ -413,10 +413,7 @@ impl Channel for EmailChannel {
Ok(Ok(messages)) => {
for (id, sender, content, ts) in messages {
{
let mut seen = self
.seen_messages
.lock()
.expect("seen_messages mutex should not be poisoned");
let mut seen = self.seen_messages.lock();
if seen.contains(&id) {
continue;
}
@ -488,20 +485,14 @@ mod tests {
#[test]
fn seen_messages_starts_empty() {
let channel = EmailChannel::new(EmailConfig::default());
let seen = channel
.seen_messages
.lock()
.expect("seen_messages mutex should not be poisoned");
let seen = channel.seen_messages.lock();
assert!(seen.is_empty());
}
#[test]
fn seen_messages_tracks_unique_ids() {
let channel = EmailChannel::new(EmailConfig::default());
let mut seen = channel
.seen_messages
.lock()
.expect("seen_messages mutex should not be poisoned");
let mut seen = channel.seen_messages.lock();
assert!(seen.insert("first-id".to_string()));
assert!(!seen.insert("first-id".to_string()));
@ -576,10 +567,7 @@ mod tests {
let channel = EmailChannel::new(config.clone());
assert_eq!(channel.config.imap_host, config.imap_host);
let seen_guard = channel
.seen_messages
.lock()
.expect("seen_messages mutex should not be poisoned");
let seen_guard = channel.seen_messages.lock();
assert_eq!(seen_guard.len(), 0);
}

View file

@ -25,9 +25,10 @@ use axum::{
routing::{get, post},
Router,
};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
@ -82,10 +83,7 @@ impl SlidingWindowRateLimiter {
let now = Instant::now();
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
let mut guard = self
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut guard = self.requests.lock();
let (requests, last_sweep) = &mut *guard;
// Periodic sweep: remove IPs with no recent requests
@ -150,10 +148,7 @@ impl IdempotencyStore {
/// Returns true if this key is new and is now recorded.
fn record_if_new(&self, key: &str) -> bool {
let now = Instant::now();
let mut keys = self
.keys
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut keys = self.keys.lock();
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
@ -739,8 +734,8 @@ mod tests {
use axum::http::HeaderValue;
use axum::response::IntoResponse;
use http_body_util::BodyExt;
use parking_lot::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
#[test]
fn security_body_limit_is_64kb() {
@ -797,19 +792,13 @@ mod tests {
assert!(limiter.allow("ip-3"));
{
let guard = limiter
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let guard = limiter.requests.lock();
assert_eq!(guard.0.len(), 3);
}
// Force a sweep by backdating last_sweep
{
let mut guard = limiter
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let mut guard = limiter.requests.lock();
guard.1 = Instant::now()
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
.unwrap();
@ -822,10 +811,7 @@ mod tests {
assert!(limiter.allow("ip-1"));
{
let guard = limiter
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let guard = limiter.requests.lock();
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
assert!(guard.0.contains_key("ip-1"));
}
@ -962,10 +948,7 @@ mod tests {
_category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
self.keys
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(key.to_string());
self.keys.lock().push(key.to_string());
Ok(())
}
@ -995,11 +978,7 @@ mod tests {
}
async fn count(&self) -> anyhow::Result<usize> {
let size = self
.keys
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.len();
let size = self.keys.lock().len();
Ok(size)
}
@ -1094,11 +1073,7 @@ mod tests {
.into_response();
assert_eq!(second.status(), StatusCode::OK);
let keys = tracking_impl
.keys
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone();
let keys = tracking_impl.keys.lock().clone();
assert_eq!(keys.len(), 2);
assert_ne!(keys[0], keys[1]);
assert!(keys[0].starts_with("webhook_msg_"));

View file

@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::time::timeout;
@ -116,9 +116,7 @@ impl LucidMemory {
}
fn in_failure_cooldown(&self) -> bool {
let Ok(guard) = self.last_failure_at.lock() else {
return false;
};
let guard = self.last_failure_at.lock();
guard
.as_ref()
@ -126,15 +124,11 @@ impl LucidMemory {
}
fn mark_failure_now(&self) {
if let Ok(mut guard) = self.last_failure_at.lock() {
*guard = Some(Instant::now());
}
*self.last_failure_at.lock() = Some(Instant::now());
}
fn clear_failure(&self) {
if let Ok(mut guard) = self.last_failure_at.lock() {
*guard = None;
}
*self.last_failure_at.lock() = None;
}
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
@ -565,11 +559,12 @@ exit 1
"local_note",
"Local sqlite auth fallback note",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entries = memory.recall("auth", 5).await.unwrap();
let entries = memory.recall("auth", 5, None).await.unwrap();
assert!(entries
.iter()

View file

@ -7,10 +7,10 @@
use anyhow::Result;
use chrono::{Duration, Local};
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
/// Response cache backed by a dedicated SQLite database.
///
@ -77,10 +77,7 @@ impl ResponseCache {
/// Look up a cached response. Returns `None` on miss or expired entry.
pub fn get(&self, key: &str) -> Result<Option<String>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let now = Local::now();
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
@ -108,10 +105,7 @@ impl ResponseCache {
/// Store a response in the cache.
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let now = Local::now().to_rfc3339();
@ -146,10 +140,7 @@ impl ResponseCache {
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
pub fn stats(&self) -> Result<(usize, u64, u64)> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let count: i64 =
conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
@ -172,10 +163,7 @@ impl ResponseCache {
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
pub fn clear(&self) -> Result<usize> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let affected = conn.execute("DELETE FROM response_cache", [])?;
Ok(affected)

View file

@ -3,9 +3,10 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry};
use super::vector;
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use uuid::Uuid;
/// SQLite-backed persistent memory — the brain
@ -185,10 +186,7 @@ impl SqliteMemory {
// Check cache
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let mut stmt =
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
@ -210,10 +208,7 @@ impl SqliteMemory {
// Store in cache + LRU eviction
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
conn.execute(
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
@ -315,10 +310,7 @@ impl SqliteMemory {
pub async fn reindex(&self) -> anyhow::Result<usize> {
// Step 1: Rebuild FTS5
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
}
@ -329,10 +321,7 @@ impl SqliteMemory {
}
let entries: Vec<(String, String)> = {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let mut stmt =
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
@ -346,10 +335,7 @@ impl SqliteMemory {
for (id, content) in &entries {
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
let bytes = vector::vec_to_bytes(&emb);
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
conn.execute(
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
params![bytes, id],
@ -381,10 +367,7 @@ impl Memory for SqliteMemory {
.await?
.map(|emb| vector::vec_to_bytes(&emb));
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let now = Local::now().to_rfc3339();
let cat = Self::category_to_str(&category);
let id = Uuid::new_v4().to_string();
@ -417,10 +400,7 @@ impl Memory for SqliteMemory {
// Compute query embedding (async, before lock)
let query_embedding = self.get_or_compute_embedding(query).await?;
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
// FTS5 BM25 keyword search
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
@ -539,10 +519,7 @@ impl Memory for SqliteMemory {
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
@ -571,10 +548,7 @@ impl Memory for SqliteMemory {
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let mut results = Vec::new();
@ -627,29 +601,20 @@ impl Memory for SqliteMemory {
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
Ok(affected > 0)
}
async fn count(&self) -> anyhow::Result<usize> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let conn = self.conn.lock();
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(count as usize)
}
async fn health_check(&self) -> bool {
self.conn
.lock()
.map(|c| c.execute_batch("SELECT 1").is_ok())
.unwrap_or(false)
self.conn.lock().execute_batch("SELECT 1").is_ok()
}
}
@ -968,7 +933,7 @@ mod tests {
#[tokio::test]
async fn schema_has_fts5_table() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
// FTS5 table should exist
let count: i64 = conn
.query_row(
@ -983,7 +948,7 @@ mod tests {
#[tokio::test]
async fn schema_has_embedding_cache() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
@ -997,7 +962,7 @@ mod tests {
#[tokio::test]
async fn schema_memories_has_embedding_column() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
// Check that embedding column exists by querying it
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
assert!(result.is_ok());
@ -1017,7 +982,7 @@ mod tests {
.await
.unwrap();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
@ -1041,7 +1006,7 @@ mod tests {
.unwrap();
mem.forget("del_key").await.unwrap();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
@ -1067,7 +1032,7 @@ mod tests {
.await
.unwrap();
let conn = mem.conn.lock().unwrap();
let conn = mem.conn.lock();
// Old content should not be findable
let old: i64 = conn
.query_row(

View file

@ -86,7 +86,7 @@ pub trait Observer: Send + Sync + 'static {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
use parking_lot::Mutex;
use std::time::Duration;
#[derive(Default)]
@ -97,12 +97,12 @@ mod tests {
impl Observer for DummyObserver {
fn record_event(&self, _event: &ObserverEvent) {
let mut guard = self.events.lock().unwrap();
let mut guard = self.events.lock();
*guard += 1;
}
fn record_metric(&self, _metric: &ObserverMetric) {
let mut guard = self.metrics.lock().unwrap();
let mut guard = self.metrics.lock();
*guard += 1;
}
@ -122,8 +122,8 @@ mod tests {
});
observer.record_metric(&ObserverMetric::TokensUsed(42));
assert_eq!(*observer.events.lock().unwrap(), 2);
assert_eq!(*observer.metrics.lock().unwrap(), 1);
assert_eq!(*observer.events.lock(), 2);
assert_eq!(*observer.metrics.lock(), 1);
}
#[test]

View file

@ -683,7 +683,7 @@ impl Provider for OpenAiCompatibleProvider {
options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
let credential = match self.credential.as_ref() {
Some(key) => key.clone(),
Some(value) => value.clone(),
None => {
let provider_name = self.name.clone();
return stream::once(async move {

View file

@ -475,7 +475,7 @@ mod tests {
/// Mock that records which model was used for each call.
struct ModelAwareMock {
calls: Arc<AtomicUsize>,
models_seen: std::sync::Mutex<Vec<String>>,
models_seen: parking_lot::Mutex<Vec<String>>,
fail_models: Vec<&'static str>,
response: &'static str,
}
@ -490,7 +490,7 @@ mod tests {
_temperature: f64,
) -> anyhow::Result<String> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.models_seen.lock().unwrap().push(model.to_string());
self.models_seen.lock().push(model.to_string());
if self.fail_models.contains(&model) {
anyhow::bail!("500 model {} unavailable", model);
}
@ -743,7 +743,7 @@ mod tests {
let calls = Arc::new(AtomicUsize::new(0));
let mock = Arc::new(ModelAwareMock {
calls: Arc::clone(&calls),
models_seen: std::sync::Mutex::new(Vec::new()),
models_seen: parking_lot::Mutex::new(Vec::new()),
fail_models: vec!["claude-opus"],
response: "ok from sonnet",
});
@ -767,7 +767,7 @@ mod tests {
.unwrap();
assert_eq!(result, "ok from sonnet");
let seen = mock.models_seen.lock().unwrap();
let seen = mock.models_seen.lock();
assert_eq!(seen.len(), 2);
assert_eq!(seen[0], "claude-opus");
assert_eq!(seen[1], "claude-sonnet");
@ -778,7 +778,7 @@ mod tests {
let calls = Arc::new(AtomicUsize::new(0));
let mock = Arc::new(ModelAwareMock {
calls: Arc::clone(&calls),
models_seen: std::sync::Mutex::new(Vec::new()),
models_seen: parking_lot::Mutex::new(Vec::new()),
fail_models: vec!["model-a", "model-b", "model-c"],
response: "never",
});
@ -802,7 +802,7 @@ mod tests {
.expect_err("all models should fail");
assert!(err.to_string().contains("All providers/models failed"));
let seen = mock.models_seen.lock().unwrap();
let seen = mock.models_seen.lock();
assert_eq!(seen.len(), 3);
}

View file

@ -164,7 +164,7 @@ mod tests {
struct MockProvider {
calls: Arc<AtomicUsize>,
response: &'static str,
last_model: std::sync::Mutex<String>,
last_model: parking_lot::Mutex<String>,
}
impl MockProvider {
@ -172,7 +172,7 @@ mod tests {
Self {
calls: Arc::new(AtomicUsize::new(0)),
response,
last_model: std::sync::Mutex::new(String::new()),
last_model: parking_lot::Mutex::new(String::new()),
}
}
@ -181,7 +181,7 @@ mod tests {
}
fn last_model(&self) -> String {
self.last_model.lock().unwrap().clone()
self.last_model.lock().clone()
}
}
@ -195,7 +195,7 @@ mod tests {
_temperature: f64,
) -> anyhow::Result<String> {
self.calls.fetch_add(1, Ordering::SeqCst);
*self.last_model.lock().unwrap() = model.to_string();
*self.last_model.lock() = model.to_string();
Ok(self.response.to_string())
}
}

View file

@ -76,6 +76,13 @@ pub struct ChatRequest<'a> {
pub tools: Option<&'a [ToolSpec]>,
}
/// Declares optional provider features.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
/// Provider can perform native tool calling without prompt-level emulation.
pub native_tool_calling: bool,
}
/// A tool result to feed back to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
@ -191,21 +198,6 @@ pub enum StreamError {
Io(#[from] std::io::Error),
}
/// Provider capabilities declaration.
///
/// Describes what features a provider supports, enabling intelligent
/// adaptation of tool calling modes and request formatting.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
/// Whether the provider supports native tool calling via API primitives.
///
/// When `true`, the provider can convert tool definitions to API-native
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
///
/// When `false`, tools must be injected via system prompt as text.
pub native_tool_calling: bool,
}
#[async_trait]
pub trait Provider: Send + Sync {
/// Query provider capabilities.
@ -329,11 +321,21 @@ pub trait Provider: Send + Sync {
/// Default implementation falls back to stream_chat_with_system with last user message.
fn stream_chat_with_history(
&self,
_messages: &[ChatMessage],
messages: &[ChatMessage],
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
let _system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.clone());
let _last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.clone())
.unwrap_or_default();
// For default implementation, we need to convert to owned strings
// This is a limitation of the default implementation
let provider_name = "unknown".to_string();

View file

@ -3,11 +3,11 @@
use crate::config::AuditConfig;
use anyhow::Result;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::fs::OpenOptions;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Mutex;
use uuid::Uuid;
/// Audit event types