diff --git a/Cargo.toml b/Cargo.toml index d1ba9ed..15d4665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ name = "zeroclaw" version = "0.1.0" edition = "2021" authors = ["theonlyhennygod"] -license = "MIT" +license = "Apache-2.0" description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant." repository = "https://github.com/zeroclaw-labs/zeroclaw" readme = "README.md" diff --git a/Dockerfile b/Dockerfile index e79f2d9..693e4de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,32 +1,35 @@ -# syntax=docker/dockerfile:1 +# syntax=docker/dockerfile:1.7 # ── Stage 1: Build ──────────────────────────────────────────── -FROM rust:1.93-slim-trixie@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder +FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e AS builder WORKDIR /app # Install build dependencies -RUN apt-get update && apt-get install -y \ - pkg-config \ +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get update && apt-get install -y \ + pkg-config \ && rm -rf /var/lib/apt/lists/* # 1. Copy manifests to cache dependencies COPY Cargo.toml Cargo.lock ./ # Create dummy main.rs to build dependencies RUN mkdir src && echo "fn main() {}" > src/main.rs -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/usr/local/cargo/git \ +RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \ + --mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \ + --mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \ cargo build --release --locked RUN rm -rf src # 2. Copy source code COPY . . -# Touch main.rs to force rebuild -RUN touch src/main.rs -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/usr/local/cargo/git \ +RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \ + --mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \ + --mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \ cargo build --release --locked && \ - strip target/release/zeroclaw + cp target/release/zeroclaw /app/zeroclaw && \ + strip /app/zeroclaw # ── Stage 2: Permissions & Config Prep ─────────────────────── FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions @@ -35,7 +38,7 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace # Create minimal config for PRODUCTION (allows binding to public interfaces) # NOTE: Provider configuration must be done via environment variables at runtime -RUN cat > /zeroclaw-data/.zeroclaw/config.toml << 'EOF' +RUN cat > /zeroclaw-data/.zeroclaw/config.toml < License: MIT Contributors - Buy Me a Coffee

Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. @@ -616,12 +615,6 @@ For high-throughput collaboration and consistent reviews: - CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md) - Security disclosure policy: [SECURITY.md](SECURITY.md) -## Support - -ZeroClaw is an open-source project maintained with passion. If you find it useful and would like to support its continued development, hardware for testing, and coffee for the maintainer, you can support me here: - -Buy Me a Coffee - ### 🙏 Special Thanks A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work: diff --git a/dev/README.md b/dev/README.md index 12fcb4b..427b566 100644 --- a/dev/README.md +++ b/dev/README.md @@ -163,5 +163,7 @@ Note: local `deny` focuses on license/source policy; advisory scanning is handle ### Build cache notes - Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data. +- The root `Dockerfile` also caches Rust `target/` (`id=zeroclaw-target`) to speed repeat local image builds. - Local CI reuses named Docker volumes for Cargo registry/git and target outputs. +- `./dev/ci.sh docker-smoke` and `./dev/ci.sh all` now use `docker buildx` local cache at `.cache/buildx-smoke` when available. - The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run. diff --git a/dev/ci.sh b/dev/ci.sh index 61bf73b..a348a19 100755 --- a/dev/ci.sh +++ b/dev/ci.sh @@ -11,12 +11,32 @@ else fi compose_cmd=(docker compose -f "$COMPOSE_FILE") +SMOKE_CACHE_DIR="${SMOKE_CACHE_DIR:-.cache/buildx-smoke}" run_in_ci() { local cmd="$1" "${compose_cmd[@]}" run --rm local-ci bash -c "$cmd" } +build_smoke_image() { + if docker buildx version >/dev/null 2>&1; then + mkdir -p "$SMOKE_CACHE_DIR" + local build_args=( + --load + --target dev + --cache-to "type=local,dest=$SMOKE_CACHE_DIR,mode=max" + -t zeroclaw-local-smoke:latest + . + ) + if [ -f "$SMOKE_CACHE_DIR/index.json" ]; then + build_args=(--cache-from "type=local,src=$SMOKE_CACHE_DIR" "${build_args[@]}") + fi + docker buildx build "${build_args[@]}" + else + DOCKER_BUILDKIT=1 docker build --target dev -t zeroclaw-local-smoke:latest . + fi +} + print_help() { cat <<'EOF' ZeroClaw Local CI in Docker @@ -88,7 +108,7 @@ case "$1" in ;; docker-smoke) - docker build --target dev -t zeroclaw-local-smoke:latest . + build_smoke_image docker run --rm zeroclaw-local-smoke:latest --version ;; @@ -98,7 +118,7 @@ case "$1" in run_in_ci "cargo build --release --locked --verbose" run_in_ci "cargo deny check licenses sources" run_in_ci "cargo audit" - docker build --target dev -t zeroclaw-local-smoke:latest . + build_smoke_image docker run --rm zeroclaw-local-smoke:latest --version ;; diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 4495736..3e5693e 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -568,7 +568,7 @@ pub async fn run( mod tests { use super::*; use async_trait::async_trait; - use std::sync::Mutex; + use parking_lot::Mutex; struct MockProvider { responses: Mutex>, @@ -592,7 +592,7 @@ mod tests { _model: &str, _temperature: f64, ) -> Result { - let mut guard = self.responses.lock().unwrap(); + let mut guard = self.responses.lock(); if guard.is_empty() { return Ok(crate::providers::ChatResponse { text: Some("done".into()), diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 10578d2..9f7d429 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -363,11 +363,7 @@ impl Channel for DiscordChannel { }; let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); - let channel_id = d - .get("channel_id") - .and_then(|c| c.as_str()) - .unwrap_or("") - .to_string(); + let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); let channel_msg = ChannelMessage { id: if message_id.is_empty() { @@ -379,10 +375,10 @@ impl Channel for DiscordChannel { reply_target: if channel_id.is_empty() { author_id.to_string() } else { - channel_id + channel_id.clone() }, - content: content.to_string(), - channel: "discord".to_string(), + content: clean_content, + channel: channel_id, timestamp: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() diff --git a/src/channels/email_channel.rs b/src/channels/email_channel.rs index 709ba18..e59e0ac 100644 --- a/src/channels/email_channel.rs +++ b/src/channels/email_channel.rs @@ -14,11 +14,11 @@ use lettre::message::SinglePart; use lettre::transport::smtp::authentication::Credentials; use lettre::{Message, SmtpTransport, Transport}; use mail_parser::{MessageParser, MimeHeaders}; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::io::Write as IoWrite; use std::net::TcpStream; -use std::sync::Mutex; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; use tokio::time::{interval, sleep}; @@ -413,10 +413,7 @@ impl Channel for EmailChannel { Ok(Ok(messages)) => { for (id, sender, content, ts) in messages { { - let mut seen = self - .seen_messages - .lock() - .expect("seen_messages mutex should not be poisoned"); + let mut seen = self.seen_messages.lock(); if seen.contains(&id) { continue; } @@ -488,20 +485,14 @@ mod tests { #[test] fn seen_messages_starts_empty() { let channel = EmailChannel::new(EmailConfig::default()); - let seen = channel - .seen_messages - .lock() - .expect("seen_messages mutex should not be poisoned"); + let seen = channel.seen_messages.lock(); assert!(seen.is_empty()); } #[test] fn seen_messages_tracks_unique_ids() { let channel = EmailChannel::new(EmailConfig::default()); - let mut seen = channel - .seen_messages - .lock() - .expect("seen_messages mutex should not be poisoned"); + let mut seen = channel.seen_messages.lock(); assert!(seen.insert("first-id".to_string())); assert!(!seen.insert("first-id".to_string())); @@ -576,10 +567,7 @@ mod tests { let channel = EmailChannel::new(config.clone()); assert_eq!(channel.config.imap_host, config.imap_host); - let seen_guard = channel - .seen_messages - .lock() - .expect("seen_messages mutex should not be poisoned"); + let seen_guard = channel.seen_messages.lock(); assert_eq!(seen_guard.len(), 0); } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 001fc35..7c618ed 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -25,9 +25,10 @@ use axum::{ routing::{get, post}, Router, }; +use parking_lot::Mutex; use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tower_http::limit::RequestBodyLimitLayer; use tower_http::timeout::TimeoutLayer; @@ -82,10 +83,7 @@ impl SlidingWindowRateLimiter { let now = Instant::now(); let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now); - let mut guard = self - .requests - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut guard = self.requests.lock(); let (requests, last_sweep) = &mut *guard; // Periodic sweep: remove IPs with no recent requests @@ -150,10 +148,7 @@ impl IdempotencyStore { /// Returns true if this key is new and is now recorded. fn record_if_new(&self, key: &str) -> bool { let now = Instant::now(); - let mut keys = self - .keys - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut keys = self.keys.lock(); keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl); @@ -739,8 +734,8 @@ mod tests { use axum::http::HeaderValue; use axum::response::IntoResponse; use http_body_util::BodyExt; + use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Mutex; #[test] fn security_body_limit_is_64kb() { @@ -797,19 +792,13 @@ mod tests { assert!(limiter.allow("ip-3")); { - let guard = limiter - .requests - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let guard = limiter.requests.lock(); assert_eq!(guard.0.len(), 3); } // Force a sweep by backdating last_sweep { - let mut guard = limiter - .requests - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut guard = limiter.requests.lock(); guard.1 = Instant::now() .checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1)) .unwrap(); @@ -822,10 +811,7 @@ mod tests { assert!(limiter.allow("ip-1")); { - let guard = limiter - .requests - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let guard = limiter.requests.lock(); assert_eq!(guard.0.len(), 1, "Stale entries should have been swept"); assert!(guard.0.contains_key("ip-1")); } @@ -962,10 +948,7 @@ mod tests { _category: MemoryCategory, _session_id: Option<&str>, ) -> anyhow::Result<()> { - self.keys - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .push(key.to_string()); + self.keys.lock().push(key.to_string()); Ok(()) } @@ -995,11 +978,7 @@ mod tests { } async fn count(&self) -> anyhow::Result { - let size = self - .keys - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .len(); + let size = self.keys.lock().len(); Ok(size) } @@ -1094,11 +1073,7 @@ mod tests { .into_response(); assert_eq!(second.status(), StatusCode::OK); - let keys = tracking_impl - .keys - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); + let keys = tracking_impl.keys.lock().clone(); assert_eq!(keys.len(), 2); assert_ne!(keys[0], keys[1]); assert!(keys[0].starts_with("webhook_msg_")); diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs index 454d0dc..7ea75a0 100644 --- a/src/memory/lucid.rs +++ b/src/memory/lucid.rs @@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory; use super::traits::{Memory, MemoryCategory, MemoryEntry}; use async_trait::async_trait; use chrono::Local; +use parking_lot::Mutex; use std::collections::HashSet; use std::path::{Path, PathBuf}; -use std::sync::Mutex; use std::time::{Duration, Instant}; use tokio::process::Command; use tokio::time::timeout; @@ -116,9 +116,7 @@ impl LucidMemory { } fn in_failure_cooldown(&self) -> bool { - let Ok(guard) = self.last_failure_at.lock() else { - return false; - }; + let guard = self.last_failure_at.lock(); guard .as_ref() @@ -126,15 +124,11 @@ impl LucidMemory { } fn mark_failure_now(&self) { - if let Ok(mut guard) = self.last_failure_at.lock() { - *guard = Some(Instant::now()); - } + *self.last_failure_at.lock() = Some(Instant::now()); } fn clear_failure(&self) { - if let Ok(mut guard) = self.last_failure_at.lock() { - *guard = None; - } + *self.last_failure_at.lock() = None; } fn to_lucid_type(category: &MemoryCategory) -> &'static str { @@ -565,11 +559,12 @@ exit 1 "local_note", "Local sqlite auth fallback note", MemoryCategory::Core, + None, ) .await .unwrap(); - let entries = memory.recall("auth", 5).await.unwrap(); + let entries = memory.recall("auth", 5, None).await.unwrap(); assert!(entries .iter() diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs index e7fb3f2..62fae6c 100644 --- a/src/memory/response_cache.rs +++ b/src/memory/response_cache.rs @@ -7,10 +7,10 @@ use anyhow::Result; use chrono::{Duration, Local}; +use parking_lot::Mutex; use rusqlite::{params, Connection}; use sha2::{Digest, Sha256}; use std::path::{Path, PathBuf}; -use std::sync::Mutex; /// Response cache backed by a dedicated SQLite database. /// @@ -77,10 +77,7 @@ impl ResponseCache { /// Look up a cached response. Returns `None` on miss or expired entry. pub fn get(&self, key: &str) -> Result> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let now = Local::now(); let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339(); @@ -108,10 +105,7 @@ impl ResponseCache { /// Store a response in the cache. pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let now = Local::now().to_rfc3339(); @@ -146,10 +140,7 @@ impl ResponseCache { /// Return cache statistics: (total_entries, total_hits, total_tokens_saved). pub fn stats(&self) -> Result<(usize, u64, u64)> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let count: i64 = conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?; @@ -172,10 +163,7 @@ impl ResponseCache { /// Wipe the entire cache (useful for `zeroclaw cache clear`). pub fn clear(&self) -> Result { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let affected = conn.execute("DELETE FROM response_cache", [])?; Ok(affected) diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index f5df9a3..b0addeb 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -3,9 +3,10 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry}; use super::vector; use async_trait::async_trait; use chrono::Local; +use parking_lot::Mutex; use rusqlite::{params, Connection}; use std::path::{Path, PathBuf}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use uuid::Uuid; /// SQLite-backed persistent memory — the brain @@ -185,10 +186,7 @@ impl SqliteMemory { // Check cache { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?; @@ -210,10 +208,7 @@ impl SqliteMemory { // Store in cache + LRU eviction { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute( "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at) @@ -315,10 +310,7 @@ impl SqliteMemory { pub async fn reindex(&self) -> anyhow::Result { // Step 1: Rebuild FTS5 { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?; } @@ -329,10 +321,7 @@ impl SqliteMemory { } let entries: Vec<(String, String)> = { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?; @@ -346,10 +335,7 @@ impl SqliteMemory { for (id, content) in &entries { if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await { let bytes = vector::vec_to_bytes(&emb); - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); conn.execute( "UPDATE memories SET embedding = ?1 WHERE id = ?2", params![bytes, id], @@ -381,10 +367,7 @@ impl Memory for SqliteMemory { .await? .map(|emb| vector::vec_to_bytes(&emb)); - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let now = Local::now().to_rfc3339(); let cat = Self::category_to_str(&category); let id = Uuid::new_v4().to_string(); @@ -417,10 +400,7 @@ impl Memory for SqliteMemory { // Compute query embedding (async, before lock) let query_embedding = self.get_or_compute_embedding(query).await?; - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); // FTS5 BM25 keyword search let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default(); @@ -539,10 +519,7 @@ impl Memory for SqliteMemory { } async fn get(&self, key: &str) -> anyhow::Result> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut stmt = conn.prepare( "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", @@ -571,10 +548,7 @@ impl Memory for SqliteMemory { category: Option<&MemoryCategory>, session_id: Option<&str>, ) -> anyhow::Result> { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let mut results = Vec::new(); @@ -627,29 +601,20 @@ impl Memory for SqliteMemory { } async fn forget(&self, key: &str) -> anyhow::Result { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?; Ok(affected > 0) } async fn count(&self) -> anyhow::Result { - let conn = self - .conn - .lock() - .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + let conn = self.conn.lock(); let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?; #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] Ok(count as usize) } async fn health_check(&self) -> bool { - self.conn - .lock() - .map(|c| c.execute_batch("SELECT 1").is_ok()) - .unwrap_or(false) + self.conn.lock().execute_batch("SELECT 1").is_ok() } } @@ -968,7 +933,7 @@ mod tests { #[tokio::test] async fn schema_has_fts5_table() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // FTS5 table should exist let count: i64 = conn .query_row( @@ -983,7 +948,7 @@ mod tests { #[tokio::test] async fn schema_has_embedding_cache() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'", @@ -997,7 +962,7 @@ mod tests { #[tokio::test] async fn schema_memories_has_embedding_column() { let (_tmp, mem) = temp_sqlite(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // Check that embedding column exists by querying it let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0"); assert!(result.is_ok()); @@ -1017,7 +982,7 @@ mod tests { .await .unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'", @@ -1041,7 +1006,7 @@ mod tests { .unwrap(); mem.forget("del_key").await.unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); let count: i64 = conn .query_row( "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'", @@ -1067,7 +1032,7 @@ mod tests { .await .unwrap(); - let conn = mem.conn.lock().unwrap(); + let conn = mem.conn.lock(); // Old content should not be findable let old: i64 = conn .query_row( diff --git a/src/observability/traits.rs b/src/observability/traits.rs index 6fb114f..d978304 100644 --- a/src/observability/traits.rs +++ b/src/observability/traits.rs @@ -86,7 +86,7 @@ pub trait Observer: Send + Sync + 'static { #[cfg(test)] mod tests { use super::*; - use std::sync::Mutex; + use parking_lot::Mutex; use std::time::Duration; #[derive(Default)] @@ -97,12 +97,12 @@ mod tests { impl Observer for DummyObserver { fn record_event(&self, _event: &ObserverEvent) { - let mut guard = self.events.lock().unwrap(); + let mut guard = self.events.lock(); *guard += 1; } fn record_metric(&self, _metric: &ObserverMetric) { - let mut guard = self.metrics.lock().unwrap(); + let mut guard = self.metrics.lock(); *guard += 1; } @@ -122,8 +122,8 @@ mod tests { }); observer.record_metric(&ObserverMetric::TokensUsed(42)); - assert_eq!(*observer.events.lock().unwrap(), 2); - assert_eq!(*observer.metrics.lock().unwrap(), 1); + assert_eq!(*observer.events.lock(), 2); + assert_eq!(*observer.metrics.lock(), 1); } #[test] diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index cdb0f0e..047c335 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -683,7 +683,7 @@ impl Provider for OpenAiCompatibleProvider { options: StreamOptions, ) -> stream::BoxStream<'static, StreamResult> { let credential = match self.credential.as_ref() { - Some(key) => key.clone(), + Some(value) => value.clone(), None => { let provider_name = self.name.clone(); return stream::once(async move { diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index ba7ae9a..be4818c 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -475,7 +475,7 @@ mod tests { /// Mock that records which model was used for each call. struct ModelAwareMock { calls: Arc, - models_seen: std::sync::Mutex>, + models_seen: parking_lot::Mutex>, fail_models: Vec<&'static str>, response: &'static str, } @@ -490,7 +490,7 @@ mod tests { _temperature: f64, ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - self.models_seen.lock().unwrap().push(model.to_string()); + self.models_seen.lock().push(model.to_string()); if self.fail_models.contains(&model) { anyhow::bail!("500 model {} unavailable", model); } @@ -743,7 +743,7 @@ mod tests { let calls = Arc::new(AtomicUsize::new(0)); let mock = Arc::new(ModelAwareMock { calls: Arc::clone(&calls), - models_seen: std::sync::Mutex::new(Vec::new()), + models_seen: parking_lot::Mutex::new(Vec::new()), fail_models: vec!["claude-opus"], response: "ok from sonnet", }); @@ -767,7 +767,7 @@ mod tests { .unwrap(); assert_eq!(result, "ok from sonnet"); - let seen = mock.models_seen.lock().unwrap(); + let seen = mock.models_seen.lock(); assert_eq!(seen.len(), 2); assert_eq!(seen[0], "claude-opus"); assert_eq!(seen[1], "claude-sonnet"); @@ -778,7 +778,7 @@ mod tests { let calls = Arc::new(AtomicUsize::new(0)); let mock = Arc::new(ModelAwareMock { calls: Arc::clone(&calls), - models_seen: std::sync::Mutex::new(Vec::new()), + models_seen: parking_lot::Mutex::new(Vec::new()), fail_models: vec!["model-a", "model-b", "model-c"], response: "never", }); @@ -802,7 +802,7 @@ mod tests { .expect_err("all models should fail"); assert!(err.to_string().contains("All providers/models failed")); - let seen = mock.models_seen.lock().unwrap(); + let seen = mock.models_seen.lock(); assert_eq!(seen.len(), 3); } diff --git a/src/providers/router.rs b/src/providers/router.rs index ccbdffb..78edde0 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -164,7 +164,7 @@ mod tests { struct MockProvider { calls: Arc, response: &'static str, - last_model: std::sync::Mutex, + last_model: parking_lot::Mutex, } impl MockProvider { @@ -172,7 +172,7 @@ mod tests { Self { calls: Arc::new(AtomicUsize::new(0)), response, - last_model: std::sync::Mutex::new(String::new()), + last_model: parking_lot::Mutex::new(String::new()), } } @@ -181,7 +181,7 @@ mod tests { } fn last_model(&self) -> String { - self.last_model.lock().unwrap().clone() + self.last_model.lock().clone() } } @@ -195,7 +195,7 @@ mod tests { _temperature: f64, ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - *self.last_model.lock().unwrap() = model.to_string(); + *self.last_model.lock() = model.to_string(); Ok(self.response.to_string()) } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index a6253e4..1bb296b 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -76,6 +76,13 @@ pub struct ChatRequest<'a> { pub tools: Option<&'a [ToolSpec]>, } +/// Declares optional provider features. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct ProviderCapabilities { + /// Provider can perform native tool calling without prompt-level emulation. + pub native_tool_calling: bool, +} + /// A tool result to feed back to the LLM. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResultMessage { @@ -191,21 +198,6 @@ pub enum StreamError { Io(#[from] std::io::Error), } -/// Provider capabilities declaration. -/// -/// Describes what features a provider supports, enabling intelligent -/// adaptation of tool calling modes and request formatting. -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct ProviderCapabilities { - /// Whether the provider supports native tool calling via API primitives. - /// - /// When `true`, the provider can convert tool definitions to API-native - /// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema). - /// - /// When `false`, tools must be injected via system prompt as text. - pub native_tool_calling: bool, -} - #[async_trait] pub trait Provider: Send + Sync { /// Query provider capabilities. @@ -329,11 +321,21 @@ pub trait Provider: Send + Sync { /// Default implementation falls back to stream_chat_with_system with last user message. fn stream_chat_with_history( &self, - _messages: &[ChatMessage], + messages: &[ChatMessage], _model: &str, _temperature: f64, _options: StreamOptions, ) -> stream::BoxStream<'static, StreamResult> { + let _system = messages + .iter() + .find(|m| m.role == "system") + .map(|m| m.content.clone()); + let _last_user = messages + .iter() + .rfind(|m| m.role == "user") + .map(|m| m.content.clone()) + .unwrap_or_default(); + // For default implementation, we need to convert to owned strings // This is a limitation of the default implementation let provider_name = "unknown".to_string(); diff --git a/src/security/audit.rs b/src/security/audit.rs index f18208f..5eb2b42 100644 --- a/src/security/audit.rs +++ b/src/security/audit.rs @@ -3,11 +3,11 @@ use crate::config::AuditConfig; use anyhow::Result; use chrono::{DateTime, Utc}; +use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::fs::OpenOptions; use std::io::Write; use std::path::PathBuf; -use std::sync::Mutex; use uuid::Uuid; /// Audit event types