Merge pull request #535 from zeroclaw-labs/pr-484-clean
fix: Discord channel replies and parking_lot::Mutex migration
This commit is contained in:
commit
acfdc34be2
18 changed files with 130 additions and 203 deletions
|
|
@ -7,7 +7,7 @@ name = "zeroclaw"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["theonlyhennygod"]
|
authors = ["theonlyhennygod"]
|
||||||
license = "MIT"
|
license = "Apache-2.0"
|
||||||
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||||
repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
|
||||||
31
Dockerfile
31
Dockerfile
|
|
@ -1,32 +1,35 @@
|
||||||
# syntax=docker/dockerfile:1
|
# syntax=docker/dockerfile:1.7
|
||||||
|
|
||||||
# ── Stage 1: Build ────────────────────────────────────────────
|
# ── Stage 1: Build ────────────────────────────────────────────
|
||||||
FROM rust:1.93-slim-trixie@sha256:9663b80a1621253d30b146454f903de48f0af925c967be48c84745537cd35d8b AS builder
|
FROM rust:1.92-slim@sha256:bf3368a992915f128293ac76917ab6e561e4dda883273c8f5c9f6f8ea37a378e AS builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install build dependencies
|
# Install build dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
pkg-config \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
|
apt-get update && apt-get install -y \
|
||||||
|
pkg-config \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# 1. Copy manifests to cache dependencies
|
# 1. Copy manifests to cache dependencies
|
||||||
COPY Cargo.toml Cargo.lock ./
|
COPY Cargo.toml Cargo.lock ./
|
||||||
# Create dummy main.rs to build dependencies
|
# Create dummy main.rs to build dependencies
|
||||||
RUN mkdir src && echo "fn main() {}" > src/main.rs
|
RUN mkdir src && echo "fn main() {}" > src/main.rs
|
||||||
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||||
--mount=type=cache,target=/usr/local/cargo/git \
|
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||||
|
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||||
cargo build --release --locked
|
cargo build --release --locked
|
||||||
RUN rm -rf src
|
RUN rm -rf src
|
||||||
|
|
||||||
# 2. Copy source code
|
# 2. Copy source code
|
||||||
COPY . .
|
COPY . .
|
||||||
# Touch main.rs to force rebuild
|
RUN --mount=type=cache,id=zeroclaw-cargo-registry,target=/usr/local/cargo/registry,sharing=locked \
|
||||||
RUN touch src/main.rs
|
--mount=type=cache,id=zeroclaw-cargo-git,target=/usr/local/cargo/git,sharing=locked \
|
||||||
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
--mount=type=cache,id=zeroclaw-target,target=/app/target,sharing=locked \
|
||||||
--mount=type=cache,target=/usr/local/cargo/git \
|
|
||||||
cargo build --release --locked && \
|
cargo build --release --locked && \
|
||||||
strip target/release/zeroclaw
|
cp target/release/zeroclaw /app/zeroclaw && \
|
||||||
|
strip /app/zeroclaw
|
||||||
|
|
||||||
# ── Stage 2: Permissions & Config Prep ───────────────────────
|
# ── Stage 2: Permissions & Config Prep ───────────────────────
|
||||||
FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions
|
FROM busybox:1.37@sha256:b3255e7dfbcd10cb367af0d409747d511aeb66dfac98cf30e97e87e4207dd76f AS permissions
|
||||||
|
|
@ -35,7 +38,7 @@ RUN mkdir -p /zeroclaw-data/.zeroclaw /zeroclaw-data/workspace
|
||||||
|
|
||||||
# Create minimal config for PRODUCTION (allows binding to public interfaces)
|
# Create minimal config for PRODUCTION (allows binding to public interfaces)
|
||||||
# NOTE: Provider configuration must be done via environment variables at runtime
|
# NOTE: Provider configuration must be done via environment variables at runtime
|
||||||
RUN cat > /zeroclaw-data/.zeroclaw/config.toml << 'EOF'
|
RUN cat > /zeroclaw-data/.zeroclaw/config.toml <<EOF
|
||||||
workspace_dir = "/zeroclaw-data/workspace"
|
workspace_dir = "/zeroclaw-data/workspace"
|
||||||
config_path = "/zeroclaw-data/.zeroclaw/config.toml"
|
config_path = "/zeroclaw-data/.zeroclaw/config.toml"
|
||||||
api_key = ""
|
api_key = ""
|
||||||
|
|
@ -65,7 +68,7 @@ RUN apt-get update && apt-get install -y \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
||||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
|
||||||
|
|
||||||
# Overwrite minimal config with DEV template (Ollama defaults)
|
# Overwrite minimal config with DEV template (Ollama defaults)
|
||||||
COPY dev/config.template.toml /zeroclaw-data/.zeroclaw/config.toml
|
COPY dev/config.template.toml /zeroclaw-data/.zeroclaw/config.toml
|
||||||
|
|
@ -92,7 +95,7 @@ CMD ["gateway", "--port", "3000", "--host", "[::]"]
|
||||||
# ── Stage 4: Production Runtime (Distroless) ─────────────────
|
# ── Stage 4: Production Runtime (Distroless) ─────────────────
|
||||||
FROM gcr.io/distroless/cc-debian13:nonroot@sha256:84fcd3c223b144b0cb6edc5ecc75641819842a9679a3a58fd6294bec47532bf7 AS release
|
FROM gcr.io/distroless/cc-debian13:nonroot@sha256:84fcd3c223b144b0cb6edc5ecc75641819842a9679a3a58fd6294bec47532bf7 AS release
|
||||||
|
|
||||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
COPY --from=builder /app/zeroclaw /usr/local/bin/zeroclaw
|
||||||
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
COPY --from=permissions /zeroclaw-data /zeroclaw-data
|
||||||
|
|
||||||
# Environment setup
|
# Environment setup
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||||
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
||||||
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=flat&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
||||||
|
|
@ -616,12 +615,6 @@ For high-throughput collaboration and consistent reviews:
|
||||||
- CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md)
|
- CI ownership and triage map: [docs/ci-map.md](docs/ci-map.md)
|
||||||
- Security disclosure policy: [SECURITY.md](SECURITY.md)
|
- Security disclosure policy: [SECURITY.md](SECURITY.md)
|
||||||
|
|
||||||
## Support
|
|
||||||
|
|
||||||
ZeroClaw is an open-source project maintained with passion. If you find it useful and would like to support its continued development, hardware for testing, and coffee for the maintainer, you can support me here:
|
|
||||||
|
|
||||||
<a href="https://buymeacoffee.com/argenistherose"><img src="https://img.shields.io/badge/Buy%20Me%20a%20Coffee-Donate-yellow.svg?style=for-the-badge&logo=buy-me-a-coffee" alt="Buy Me a Coffee" /></a>
|
|
||||||
|
|
||||||
### 🙏 Special Thanks
|
### 🙏 Special Thanks
|
||||||
|
|
||||||
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:
|
A heartfelt thank you to the communities and institutions that inspire and fuel this open-source work:
|
||||||
|
|
|
||||||
|
|
@ -163,5 +163,7 @@ Note: local `deny` focuses on license/source policy; advisory scanning is handle
|
||||||
### Build cache notes
|
### Build cache notes
|
||||||
|
|
||||||
- Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data.
|
- Both `Dockerfile` and `dev/ci/Dockerfile` use BuildKit cache mounts for Cargo registry/git data.
|
||||||
|
- The root `Dockerfile` also caches Rust `target/` (`id=zeroclaw-target`) to speed repeat local image builds.
|
||||||
- Local CI reuses named Docker volumes for Cargo registry/git and target outputs.
|
- Local CI reuses named Docker volumes for Cargo registry/git and target outputs.
|
||||||
|
- `./dev/ci.sh docker-smoke` and `./dev/ci.sh all` now use `docker buildx` local cache at `.cache/buildx-smoke` when available.
|
||||||
- The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run.
|
- The CI image keeps Rust toolchain defaults from `rust:1.92-slim` and installs pinned toolchain `1.92.0` (no custom `CARGO_HOME`/`RUSTUP_HOME` overrides), preventing repeated toolchain bootstrapping on each run.
|
||||||
|
|
|
||||||
24
dev/ci.sh
24
dev/ci.sh
|
|
@ -11,12 +11,32 @@ else
|
||||||
fi
|
fi
|
||||||
|
|
||||||
compose_cmd=(docker compose -f "$COMPOSE_FILE")
|
compose_cmd=(docker compose -f "$COMPOSE_FILE")
|
||||||
|
SMOKE_CACHE_DIR="${SMOKE_CACHE_DIR:-.cache/buildx-smoke}"
|
||||||
|
|
||||||
run_in_ci() {
|
run_in_ci() {
|
||||||
local cmd="$1"
|
local cmd="$1"
|
||||||
"${compose_cmd[@]}" run --rm local-ci bash -c "$cmd"
|
"${compose_cmd[@]}" run --rm local-ci bash -c "$cmd"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
build_smoke_image() {
|
||||||
|
if docker buildx version >/dev/null 2>&1; then
|
||||||
|
mkdir -p "$SMOKE_CACHE_DIR"
|
||||||
|
local build_args=(
|
||||||
|
--load
|
||||||
|
--target dev
|
||||||
|
--cache-to "type=local,dest=$SMOKE_CACHE_DIR,mode=max"
|
||||||
|
-t zeroclaw-local-smoke:latest
|
||||||
|
.
|
||||||
|
)
|
||||||
|
if [ -f "$SMOKE_CACHE_DIR/index.json" ]; then
|
||||||
|
build_args=(--cache-from "type=local,src=$SMOKE_CACHE_DIR" "${build_args[@]}")
|
||||||
|
fi
|
||||||
|
docker buildx build "${build_args[@]}"
|
||||||
|
else
|
||||||
|
DOCKER_BUILDKIT=1 docker build --target dev -t zeroclaw-local-smoke:latest .
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
print_help() {
|
print_help() {
|
||||||
cat <<'EOF'
|
cat <<'EOF'
|
||||||
ZeroClaw Local CI in Docker
|
ZeroClaw Local CI in Docker
|
||||||
|
|
@ -88,7 +108,7 @@ case "$1" in
|
||||||
;;
|
;;
|
||||||
|
|
||||||
docker-smoke)
|
docker-smoke)
|
||||||
docker build --target dev -t zeroclaw-local-smoke:latest .
|
build_smoke_image
|
||||||
docker run --rm zeroclaw-local-smoke:latest --version
|
docker run --rm zeroclaw-local-smoke:latest --version
|
||||||
;;
|
;;
|
||||||
|
|
||||||
|
|
@ -98,7 +118,7 @@ case "$1" in
|
||||||
run_in_ci "cargo build --release --locked --verbose"
|
run_in_ci "cargo build --release --locked --verbose"
|
||||||
run_in_ci "cargo deny check licenses sources"
|
run_in_ci "cargo deny check licenses sources"
|
||||||
run_in_ci "cargo audit"
|
run_in_ci "cargo audit"
|
||||||
docker build --target dev -t zeroclaw-local-smoke:latest .
|
build_smoke_image
|
||||||
docker run --rm zeroclaw-local-smoke:latest --version
|
docker run --rm zeroclaw-local-smoke:latest --version
|
||||||
;;
|
;;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -568,7 +568,7 @@ pub async fn run(
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::sync::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
|
||||||
struct MockProvider {
|
struct MockProvider {
|
||||||
responses: Mutex<Vec<crate::providers::ChatResponse>>,
|
responses: Mutex<Vec<crate::providers::ChatResponse>>,
|
||||||
|
|
@ -592,7 +592,7 @@ mod tests {
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> Result<crate::providers::ChatResponse> {
|
) -> Result<crate::providers::ChatResponse> {
|
||||||
let mut guard = self.responses.lock().unwrap();
|
let mut guard = self.responses.lock();
|
||||||
if guard.is_empty() {
|
if guard.is_empty() {
|
||||||
return Ok(crate::providers::ChatResponse {
|
return Ok(crate::providers::ChatResponse {
|
||||||
text: Some("done".into()),
|
text: Some("done".into()),
|
||||||
|
|
|
||||||
|
|
@ -363,11 +363,7 @@ impl Channel for DiscordChannel {
|
||||||
};
|
};
|
||||||
|
|
||||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||||
let channel_id = d
|
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||||
.get("channel_id")
|
|
||||||
.and_then(|c| c.as_str())
|
|
||||||
.unwrap_or("")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let channel_msg = ChannelMessage {
|
let channel_msg = ChannelMessage {
|
||||||
id: if message_id.is_empty() {
|
id: if message_id.is_empty() {
|
||||||
|
|
@ -379,10 +375,10 @@ impl Channel for DiscordChannel {
|
||||||
reply_target: if channel_id.is_empty() {
|
reply_target: if channel_id.is_empty() {
|
||||||
author_id.to_string()
|
author_id.to_string()
|
||||||
} else {
|
} else {
|
||||||
channel_id
|
channel_id.clone()
|
||||||
},
|
},
|
||||||
content: content.to_string(),
|
content: clean_content,
|
||||||
channel: "discord".to_string(),
|
channel: channel_id,
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,11 @@ use lettre::message::SinglePart;
|
||||||
use lettre::transport::smtp::authentication::Credentials;
|
use lettre::transport::smtp::authentication::Credentials;
|
||||||
use lettre::{Message, SmtpTransport, Transport};
|
use lettre::{Message, SmtpTransport, Transport};
|
||||||
use mail_parser::{MessageParser, MimeHeaders};
|
use mail_parser::{MessageParser, MimeHeaders};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::io::Write as IoWrite;
|
use std::io::Write as IoWrite;
|
||||||
use std::net::TcpStream;
|
use std::net::TcpStream;
|
||||||
use std::sync::Mutex;
|
|
||||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::time::{interval, sleep};
|
use tokio::time::{interval, sleep};
|
||||||
|
|
@ -413,10 +413,7 @@ impl Channel for EmailChannel {
|
||||||
Ok(Ok(messages)) => {
|
Ok(Ok(messages)) => {
|
||||||
for (id, sender, content, ts) in messages {
|
for (id, sender, content, ts) in messages {
|
||||||
{
|
{
|
||||||
let mut seen = self
|
let mut seen = self.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
if seen.contains(&id) {
|
if seen.contains(&id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -488,20 +485,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn seen_messages_starts_empty() {
|
fn seen_messages_starts_empty() {
|
||||||
let channel = EmailChannel::new(EmailConfig::default());
|
let channel = EmailChannel::new(EmailConfig::default());
|
||||||
let seen = channel
|
let seen = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
assert!(seen.is_empty());
|
assert!(seen.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn seen_messages_tracks_unique_ids() {
|
fn seen_messages_tracks_unique_ids() {
|
||||||
let channel = EmailChannel::new(EmailConfig::default());
|
let channel = EmailChannel::new(EmailConfig::default());
|
||||||
let mut seen = channel
|
let mut seen = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
|
|
||||||
assert!(seen.insert("first-id".to_string()));
|
assert!(seen.insert("first-id".to_string()));
|
||||||
assert!(!seen.insert("first-id".to_string()));
|
assert!(!seen.insert("first-id".to_string()));
|
||||||
|
|
@ -576,10 +567,7 @@ mod tests {
|
||||||
let channel = EmailChannel::new(config.clone());
|
let channel = EmailChannel::new(config.clone());
|
||||||
assert_eq!(channel.config.imap_host, config.imap_host);
|
assert_eq!(channel.config.imap_host, config.imap_host);
|
||||||
|
|
||||||
let seen_guard = channel
|
let seen_guard = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
assert_eq!(seen_guard.len(), 0);
|
assert_eq!(seen_guard.len(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,10 @@ use axum::{
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
use tower_http::timeout::TimeoutLayer;
|
use tower_http::timeout::TimeoutLayer;
|
||||||
|
|
@ -82,10 +83,7 @@ impl SlidingWindowRateLimiter {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
||||||
|
|
||||||
let mut guard = self
|
let mut guard = self.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
let (requests, last_sweep) = &mut *guard;
|
let (requests, last_sweep) = &mut *guard;
|
||||||
|
|
||||||
// Periodic sweep: remove IPs with no recent requests
|
// Periodic sweep: remove IPs with no recent requests
|
||||||
|
|
@ -150,10 +148,7 @@ impl IdempotencyStore {
|
||||||
/// Returns true if this key is new and is now recorded.
|
/// Returns true if this key is new and is now recorded.
|
||||||
fn record_if_new(&self, key: &str) -> bool {
|
fn record_if_new(&self, key: &str) -> bool {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let mut keys = self
|
let mut keys = self.keys.lock();
|
||||||
.keys
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
|
|
||||||
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
||||||
|
|
||||||
|
|
@ -739,8 +734,8 @@ mod tests {
|
||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn security_body_limit_is_64kb() {
|
fn security_body_limit_is_64kb() {
|
||||||
|
|
@ -797,19 +792,13 @@ mod tests {
|
||||||
assert!(limiter.allow("ip-3"));
|
assert!(limiter.allow("ip-3"));
|
||||||
|
|
||||||
{
|
{
|
||||||
let guard = limiter
|
let guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
assert_eq!(guard.0.len(), 3);
|
assert_eq!(guard.0.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force a sweep by backdating last_sweep
|
// Force a sweep by backdating last_sweep
|
||||||
{
|
{
|
||||||
let mut guard = limiter
|
let mut guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
guard.1 = Instant::now()
|
guard.1 = Instant::now()
|
||||||
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -822,10 +811,7 @@ mod tests {
|
||||||
assert!(limiter.allow("ip-1"));
|
assert!(limiter.allow("ip-1"));
|
||||||
|
|
||||||
{
|
{
|
||||||
let guard = limiter
|
let guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
|
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
|
||||||
assert!(guard.0.contains_key("ip-1"));
|
assert!(guard.0.contains_key("ip-1"));
|
||||||
}
|
}
|
||||||
|
|
@ -962,10 +948,7 @@ mod tests {
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
_session_id: Option<&str>,
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.keys
|
self.keys.lock().push(key.to_string());
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.push(key.to_string());
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -995,11 +978,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count(&self) -> anyhow::Result<usize> {
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
let size = self
|
let size = self.keys.lock().len();
|
||||||
.keys
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.len();
|
|
||||||
Ok(size)
|
Ok(size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1094,11 +1073,7 @@ mod tests {
|
||||||
.into_response();
|
.into_response();
|
||||||
assert_eq!(second.status(), StatusCode::OK);
|
assert_eq!(second.status(), StatusCode::OK);
|
||||||
|
|
||||||
let keys = tracking_impl
|
let keys = tracking_impl.keys.lock().clone();
|
||||||
.keys
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.clone();
|
|
||||||
assert_eq!(keys.len(), 2);
|
assert_eq!(keys.len(), 2);
|
||||||
assert_ne!(keys[0], keys[1]);
|
assert_ne!(keys[0], keys[1]);
|
||||||
assert!(keys[0].starts_with("webhook_msg_"));
|
assert!(keys[0].starts_with("webhook_msg_"));
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
|
||||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Mutex;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
@ -116,9 +116,7 @@ impl LucidMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn in_failure_cooldown(&self) -> bool {
|
fn in_failure_cooldown(&self) -> bool {
|
||||||
let Ok(guard) = self.last_failure_at.lock() else {
|
let guard = self.last_failure_at.lock();
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
guard
|
guard
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -126,15 +124,11 @@ impl LucidMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_failure_now(&self) {
|
fn mark_failure_now(&self) {
|
||||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
*self.last_failure_at.lock() = Some(Instant::now());
|
||||||
*guard = Some(Instant::now());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clear_failure(&self) {
|
fn clear_failure(&self) {
|
||||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
*self.last_failure_at.lock() = None;
|
||||||
*guard = None;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
||||||
|
|
@ -565,11 +559,12 @@ exit 1
|
||||||
"local_note",
|
"local_note",
|
||||||
"Local sqlite auth fallback note",
|
"Local sqlite auth fallback note",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let entries = memory.recall("auth", 5).await.unwrap();
|
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{Duration, Local};
|
use chrono::{Duration, Local};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{params, Connection};
|
use rusqlite::{params, Connection};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
/// Response cache backed by a dedicated SQLite database.
|
/// Response cache backed by a dedicated SQLite database.
|
||||||
///
|
///
|
||||||
|
|
@ -77,10 +77,7 @@ impl ResponseCache {
|
||||||
|
|
||||||
/// Look up a cached response. Returns `None` on miss or expired entry.
|
/// Look up a cached response. Returns `None` on miss or expired entry.
|
||||||
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
pub fn get(&self, key: &str) -> Result<Option<String>> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let now = Local::now();
|
let now = Local::now();
|
||||||
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
let cutoff = (now - Duration::minutes(self.ttl_minutes)).to_rfc3339();
|
||||||
|
|
@ -108,10 +105,7 @@ impl ResponseCache {
|
||||||
|
|
||||||
/// Store a response in the cache.
|
/// Store a response in the cache.
|
||||||
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let now = Local::now().to_rfc3339();
|
let now = Local::now().to_rfc3339();
|
||||||
|
|
||||||
|
|
@ -146,10 +140,7 @@ impl ResponseCache {
|
||||||
|
|
||||||
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
/// Return cache statistics: (total_entries, total_hits, total_tokens_saved).
|
||||||
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
pub fn stats(&self) -> Result<(usize, u64, u64)> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let count: i64 =
|
let count: i64 =
|
||||||
conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
|
conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?;
|
||||||
|
|
@ -172,10 +163,7 @@ impl ResponseCache {
|
||||||
|
|
||||||
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
|
/// Wipe the entire cache (useful for `zeroclaw cache clear`).
|
||||||
pub fn clear(&self) -> Result<usize> {
|
pub fn clear(&self) -> Result<usize> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let affected = conn.execute("DELETE FROM response_cache", [])?;
|
let affected = conn.execute("DELETE FROM response_cache", [])?;
|
||||||
Ok(affected)
|
Ok(affected)
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,10 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
use super::vector;
|
use super::vector;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{params, Connection};
|
use rusqlite::{params, Connection};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// SQLite-backed persistent memory — the brain
|
/// SQLite-backed persistent memory — the brain
|
||||||
|
|
@ -185,10 +186,7 @@ impl SqliteMemory {
|
||||||
|
|
||||||
// Check cache
|
// Check cache
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
||||||
|
|
@ -210,10 +208,7 @@ impl SqliteMemory {
|
||||||
|
|
||||||
// Store in cache + LRU eviction
|
// Store in cache + LRU eviction
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
||||||
|
|
@ -315,10 +310,7 @@ impl SqliteMemory {
|
||||||
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
||||||
// Step 1: Rebuild FTS5
|
// Step 1: Rebuild FTS5
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
||||||
}
|
}
|
||||||
|
|
@ -329,10 +321,7 @@ impl SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
let entries: Vec<(String, String)> = {
|
let entries: Vec<(String, String)> = {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
||||||
|
|
@ -346,10 +335,7 @@ impl SqliteMemory {
|
||||||
for (id, content) in &entries {
|
for (id, content) in &entries {
|
||||||
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
||||||
let bytes = vector::vec_to_bytes(&emb);
|
let bytes = vector::vec_to_bytes(&emb);
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
||||||
params![bytes, id],
|
params![bytes, id],
|
||||||
|
|
@ -381,10 +367,7 @@ impl Memory for SqliteMemory {
|
||||||
.await?
|
.await?
|
||||||
.map(|emb| vector::vec_to_bytes(&emb));
|
.map(|emb| vector::vec_to_bytes(&emb));
|
||||||
|
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let now = Local::now().to_rfc3339();
|
let now = Local::now().to_rfc3339();
|
||||||
let cat = Self::category_to_str(&category);
|
let cat = Self::category_to_str(&category);
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
@ -417,10 +400,7 @@ impl Memory for SqliteMemory {
|
||||||
// Compute query embedding (async, before lock)
|
// Compute query embedding (async, before lock)
|
||||||
let query_embedding = self.get_or_compute_embedding(query).await?;
|
let query_embedding = self.get_or_compute_embedding(query).await?;
|
||||||
|
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
// FTS5 BM25 keyword search
|
// FTS5 BM25 keyword search
|
||||||
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
||||||
|
|
@ -539,10 +519,7 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||||
|
|
@ -571,10 +548,7 @@ impl Memory for SqliteMemory {
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
|
@ -627,29 +601,20 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||||
Ok(affected > 0)
|
Ok(affected > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count(&self) -> anyhow::Result<usize> {
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||||
Ok(count as usize)
|
Ok(count as usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
async fn health_check(&self) -> bool {
|
||||||
self.conn
|
self.conn.lock().execute_batch("SELECT 1").is_ok()
|
||||||
.lock()
|
|
||||||
.map(|c| c.execute_batch("SELECT 1").is_ok())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -968,7 +933,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn schema_has_fts5_table() {
|
async fn schema_has_fts5_table() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
// FTS5 table should exist
|
// FTS5 table should exist
|
||||||
let count: i64 = conn
|
let count: i64 = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
|
|
@ -983,7 +948,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn schema_has_embedding_cache() {
|
async fn schema_has_embedding_cache() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
let count: i64 = conn
|
let count: i64 = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
|
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
|
||||||
|
|
@ -997,7 +962,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn schema_memories_has_embedding_column() {
|
async fn schema_memories_has_embedding_column() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
// Check that embedding column exists by querying it
|
// Check that embedding column exists by querying it
|
||||||
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
|
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
@ -1017,7 +982,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
let count: i64 = conn
|
let count: i64 = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
|
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
|
||||||
|
|
@ -1041,7 +1006,7 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.forget("del_key").await.unwrap();
|
mem.forget("del_key").await.unwrap();
|
||||||
|
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
let count: i64 = conn
|
let count: i64 = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
|
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
|
||||||
|
|
@ -1067,7 +1032,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let conn = mem.conn.lock().unwrap();
|
let conn = mem.conn.lock();
|
||||||
// Old content should not be findable
|
// Old content should not be findable
|
||||||
let old: i64 = conn
|
let old: i64 = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ pub trait Observer: Send + Sync + 'static {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::sync::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|
@ -97,12 +97,12 @@ mod tests {
|
||||||
|
|
||||||
impl Observer for DummyObserver {
|
impl Observer for DummyObserver {
|
||||||
fn record_event(&self, _event: &ObserverEvent) {
|
fn record_event(&self, _event: &ObserverEvent) {
|
||||||
let mut guard = self.events.lock().unwrap();
|
let mut guard = self.events.lock();
|
||||||
*guard += 1;
|
*guard += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn record_metric(&self, _metric: &ObserverMetric) {
|
fn record_metric(&self, _metric: &ObserverMetric) {
|
||||||
let mut guard = self.metrics.lock().unwrap();
|
let mut guard = self.metrics.lock();
|
||||||
*guard += 1;
|
*guard += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,8 +122,8 @@ mod tests {
|
||||||
});
|
});
|
||||||
observer.record_metric(&ObserverMetric::TokensUsed(42));
|
observer.record_metric(&ObserverMetric::TokensUsed(42));
|
||||||
|
|
||||||
assert_eq!(*observer.events.lock().unwrap(), 2);
|
assert_eq!(*observer.events.lock(), 2);
|
||||||
assert_eq!(*observer.metrics.lock().unwrap(), 1);
|
assert_eq!(*observer.metrics.lock(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -683,7 +683,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
options: StreamOptions,
|
options: StreamOptions,
|
||||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
let credential = match self.credential.as_ref() {
|
let credential = match self.credential.as_ref() {
|
||||||
Some(key) => key.clone(),
|
Some(value) => value.clone(),
|
||||||
None => {
|
None => {
|
||||||
let provider_name = self.name.clone();
|
let provider_name = self.name.clone();
|
||||||
return stream::once(async move {
|
return stream::once(async move {
|
||||||
|
|
|
||||||
|
|
@ -475,7 +475,7 @@ mod tests {
|
||||||
/// Mock that records which model was used for each call.
|
/// Mock that records which model was used for each call.
|
||||||
struct ModelAwareMock {
|
struct ModelAwareMock {
|
||||||
calls: Arc<AtomicUsize>,
|
calls: Arc<AtomicUsize>,
|
||||||
models_seen: std::sync::Mutex<Vec<String>>,
|
models_seen: parking_lot::Mutex<Vec<String>>,
|
||||||
fail_models: Vec<&'static str>,
|
fail_models: Vec<&'static str>,
|
||||||
response: &'static str,
|
response: &'static str,
|
||||||
}
|
}
|
||||||
|
|
@ -490,7 +490,7 @@ mod tests {
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
self.models_seen.lock().unwrap().push(model.to_string());
|
self.models_seen.lock().push(model.to_string());
|
||||||
if self.fail_models.contains(&model) {
|
if self.fail_models.contains(&model) {
|
||||||
anyhow::bail!("500 model {} unavailable", model);
|
anyhow::bail!("500 model {} unavailable", model);
|
||||||
}
|
}
|
||||||
|
|
@ -743,7 +743,7 @@ mod tests {
|
||||||
let calls = Arc::new(AtomicUsize::new(0));
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
let mock = Arc::new(ModelAwareMock {
|
let mock = Arc::new(ModelAwareMock {
|
||||||
calls: Arc::clone(&calls),
|
calls: Arc::clone(&calls),
|
||||||
models_seen: std::sync::Mutex::new(Vec::new()),
|
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||||
fail_models: vec!["claude-opus"],
|
fail_models: vec!["claude-opus"],
|
||||||
response: "ok from sonnet",
|
response: "ok from sonnet",
|
||||||
});
|
});
|
||||||
|
|
@ -767,7 +767,7 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result, "ok from sonnet");
|
assert_eq!(result, "ok from sonnet");
|
||||||
|
|
||||||
let seen = mock.models_seen.lock().unwrap();
|
let seen = mock.models_seen.lock();
|
||||||
assert_eq!(seen.len(), 2);
|
assert_eq!(seen.len(), 2);
|
||||||
assert_eq!(seen[0], "claude-opus");
|
assert_eq!(seen[0], "claude-opus");
|
||||||
assert_eq!(seen[1], "claude-sonnet");
|
assert_eq!(seen[1], "claude-sonnet");
|
||||||
|
|
@ -778,7 +778,7 @@ mod tests {
|
||||||
let calls = Arc::new(AtomicUsize::new(0));
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
let mock = Arc::new(ModelAwareMock {
|
let mock = Arc::new(ModelAwareMock {
|
||||||
calls: Arc::clone(&calls),
|
calls: Arc::clone(&calls),
|
||||||
models_seen: std::sync::Mutex::new(Vec::new()),
|
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||||
fail_models: vec!["model-a", "model-b", "model-c"],
|
fail_models: vec!["model-a", "model-b", "model-c"],
|
||||||
response: "never",
|
response: "never",
|
||||||
});
|
});
|
||||||
|
|
@ -802,7 +802,7 @@ mod tests {
|
||||||
.expect_err("all models should fail");
|
.expect_err("all models should fail");
|
||||||
assert!(err.to_string().contains("All providers/models failed"));
|
assert!(err.to_string().contains("All providers/models failed"));
|
||||||
|
|
||||||
let seen = mock.models_seen.lock().unwrap();
|
let seen = mock.models_seen.lock();
|
||||||
assert_eq!(seen.len(), 3);
|
assert_eq!(seen.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,7 @@ mod tests {
|
||||||
struct MockProvider {
|
struct MockProvider {
|
||||||
calls: Arc<AtomicUsize>,
|
calls: Arc<AtomicUsize>,
|
||||||
response: &'static str,
|
response: &'static str,
|
||||||
last_model: std::sync::Mutex<String>,
|
last_model: parking_lot::Mutex<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MockProvider {
|
impl MockProvider {
|
||||||
|
|
@ -172,7 +172,7 @@ mod tests {
|
||||||
Self {
|
Self {
|
||||||
calls: Arc::new(AtomicUsize::new(0)),
|
calls: Arc::new(AtomicUsize::new(0)),
|
||||||
response,
|
response,
|
||||||
last_model: std::sync::Mutex::new(String::new()),
|
last_model: parking_lot::Mutex::new(String::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -181,7 +181,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn last_model(&self) -> String {
|
fn last_model(&self) -> String {
|
||||||
self.last_model.lock().unwrap().clone()
|
self.last_model.lock().clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -195,7 +195,7 @@ mod tests {
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
*self.last_model.lock().unwrap() = model.to_string();
|
*self.last_model.lock() = model.to_string();
|
||||||
Ok(self.response.to_string())
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,13 @@ pub struct ChatRequest<'a> {
|
||||||
pub tools: Option<&'a [ToolSpec]>,
|
pub tools: Option<&'a [ToolSpec]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Declares optional provider features.
|
||||||
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||||
|
pub struct ProviderCapabilities {
|
||||||
|
/// Provider can perform native tool calling without prompt-level emulation.
|
||||||
|
pub native_tool_calling: bool,
|
||||||
|
}
|
||||||
|
|
||||||
/// A tool result to feed back to the LLM.
|
/// A tool result to feed back to the LLM.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResultMessage {
|
pub struct ToolResultMessage {
|
||||||
|
|
@ -191,21 +198,6 @@ pub enum StreamError {
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Provider capabilities declaration.
|
|
||||||
///
|
|
||||||
/// Describes what features a provider supports, enabling intelligent
|
|
||||||
/// adaptation of tool calling modes and request formatting.
|
|
||||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
|
||||||
pub struct ProviderCapabilities {
|
|
||||||
/// Whether the provider supports native tool calling via API primitives.
|
|
||||||
///
|
|
||||||
/// When `true`, the provider can convert tool definitions to API-native
|
|
||||||
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
|
|
||||||
///
|
|
||||||
/// When `false`, tools must be injected via system prompt as text.
|
|
||||||
pub native_tool_calling: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
/// Query provider capabilities.
|
/// Query provider capabilities.
|
||||||
|
|
@ -329,11 +321,21 @@ pub trait Provider: Send + Sync {
|
||||||
/// Default implementation falls back to stream_chat_with_system with last user message.
|
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||||
fn stream_chat_with_history(
|
fn stream_chat_with_history(
|
||||||
&self,
|
&self,
|
||||||
_messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
_options: StreamOptions,
|
_options: StreamOptions,
|
||||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
let _system = messages
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.role == "system")
|
||||||
|
.map(|m| m.content.clone());
|
||||||
|
let _last_user = messages
|
||||||
|
.iter()
|
||||||
|
.rfind(|m| m.role == "user")
|
||||||
|
.map(|m| m.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
// For default implementation, we need to convert to owned strings
|
// For default implementation, we need to convert to owned strings
|
||||||
// This is a limitation of the default implementation
|
// This is a limitation of the default implementation
|
||||||
let provider_name = "unknown".to_string();
|
let provider_name = "unknown".to_string();
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@
|
||||||
use crate::config::AuditConfig;
|
use crate::config::AuditConfig;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Mutex;
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Audit event types
|
/// Audit event types
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue