fix(ci): repair parking_lot migration regressions in PR #535
This commit is contained in:
parent
ee05d62ce4
commit
9e0958dee5
10 changed files with 51 additions and 115 deletions
|
|
@ -375,9 +375,9 @@ impl Channel for DiscordChannel {
|
||||||
reply_target: if channel_id.is_empty() {
|
reply_target: if channel_id.is_empty() {
|
||||||
author_id.to_string()
|
author_id.to_string()
|
||||||
} else {
|
} else {
|
||||||
channel_id
|
channel_id.clone()
|
||||||
},
|
},
|
||||||
content: content.to_string(),
|
content: clean_content,
|
||||||
channel: channel_id,
|
channel: channel_id,
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,11 @@ use lettre::message::SinglePart;
|
||||||
use lettre::transport::smtp::authentication::Credentials;
|
use lettre::transport::smtp::authentication::Credentials;
|
||||||
use lettre::{Message, SmtpTransport, Transport};
|
use lettre::{Message, SmtpTransport, Transport};
|
||||||
use mail_parser::{MessageParser, MimeHeaders};
|
use mail_parser::{MessageParser, MimeHeaders};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::io::Write as IoWrite;
|
use std::io::Write as IoWrite;
|
||||||
use std::net::TcpStream;
|
use std::net::TcpStream;
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::time::{interval, sleep};
|
use tokio::time::{interval, sleep};
|
||||||
|
|
@ -413,10 +413,7 @@ impl Channel for EmailChannel {
|
||||||
Ok(Ok(messages)) => {
|
Ok(Ok(messages)) => {
|
||||||
for (id, sender, content, ts) in messages {
|
for (id, sender, content, ts) in messages {
|
||||||
{
|
{
|
||||||
let mut seen = self
|
let mut seen = self.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
;
|
|
||||||
if seen.contains(&id) {
|
if seen.contains(&id) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -488,20 +485,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn seen_messages_starts_empty() {
|
fn seen_messages_starts_empty() {
|
||||||
let channel = EmailChannel::new(EmailConfig::default());
|
let channel = EmailChannel::new(EmailConfig::default());
|
||||||
let seen = channel
|
let seen = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
assert!(seen.is_empty());
|
assert!(seen.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn seen_messages_tracks_unique_ids() {
|
fn seen_messages_tracks_unique_ids() {
|
||||||
let channel = EmailChannel::new(EmailConfig::default());
|
let channel = EmailChannel::new(EmailConfig::default());
|
||||||
let mut seen = channel
|
let mut seen = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
|
|
||||||
assert!(seen.insert("first-id".to_string()));
|
assert!(seen.insert("first-id".to_string()));
|
||||||
assert!(!seen.insert("first-id".to_string()));
|
assert!(!seen.insert("first-id".to_string()));
|
||||||
|
|
@ -576,10 +567,7 @@ mod tests {
|
||||||
let channel = EmailChannel::new(config.clone());
|
let channel = EmailChannel::new(config.clone());
|
||||||
assert_eq!(channel.config.imap_host, config.imap_host);
|
assert_eq!(channel.config.imap_host, config.imap_host);
|
||||||
|
|
||||||
let seen_guard = channel
|
let seen_guard = channel.seen_messages.lock();
|
||||||
.seen_messages
|
|
||||||
.lock()
|
|
||||||
.expect("seen_messages mutex should not be poisoned");
|
|
||||||
assert_eq!(seen_guard.len(), 0);
|
assert_eq!(seen_guard.len(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,9 @@ use axum::{
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
|
|
@ -83,9 +83,7 @@ impl SlidingWindowRateLimiter {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
||||||
|
|
||||||
let mut guard = self
|
let mut guard = self.requests.lock();
|
||||||
.requests
|
|
||||||
.lock();
|
|
||||||
let (requests, last_sweep) = &mut *guard;
|
let (requests, last_sweep) = &mut *guard;
|
||||||
|
|
||||||
// Periodic sweep: remove IPs with no recent requests
|
// Periodic sweep: remove IPs with no recent requests
|
||||||
|
|
@ -150,9 +148,7 @@ impl IdempotencyStore {
|
||||||
/// Returns true if this key is new and is now recorded.
|
/// Returns true if this key is new and is now recorded.
|
||||||
fn record_if_new(&self, key: &str) -> bool {
|
fn record_if_new(&self, key: &str) -> bool {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let mut keys = self
|
let mut keys = self.keys.lock();
|
||||||
.keys
|
|
||||||
.lock();
|
|
||||||
|
|
||||||
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
||||||
|
|
||||||
|
|
@ -738,8 +734,8 @@ mod tests {
|
||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn security_body_limit_is_64kb() {
|
fn security_body_limit_is_64kb() {
|
||||||
|
|
@ -796,19 +792,13 @@ mod tests {
|
||||||
assert!(limiter.allow("ip-3"));
|
assert!(limiter.allow("ip-3"));
|
||||||
|
|
||||||
{
|
{
|
||||||
let guard = limiter
|
let guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
assert_eq!(guard.0.len(), 3);
|
assert_eq!(guard.0.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force a sweep by backdating last_sweep
|
// Force a sweep by backdating last_sweep
|
||||||
{
|
{
|
||||||
let mut guard = limiter
|
let mut guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
guard.1 = Instant::now()
|
guard.1 = Instant::now()
|
||||||
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -821,10 +811,7 @@ mod tests {
|
||||||
assert!(limiter.allow("ip-1"));
|
assert!(limiter.allow("ip-1"));
|
||||||
|
|
||||||
{
|
{
|
||||||
let guard = limiter
|
let guard = limiter.requests.lock();
|
||||||
.requests
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
|
||||||
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
|
assert_eq!(guard.0.len(), 1, "Stale entries should have been swept");
|
||||||
assert!(guard.0.contains_key("ip-1"));
|
assert!(guard.0.contains_key("ip-1"));
|
||||||
}
|
}
|
||||||
|
|
@ -961,10 +948,7 @@ mod tests {
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
_session_id: Option<&str>,
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.keys
|
self.keys.lock().push(key.to_string());
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.push(key.to_string());
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -994,11 +978,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count(&self) -> anyhow::Result<usize> {
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
let size = self
|
let size = self.keys.lock().len();
|
||||||
.keys
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.len();
|
|
||||||
Ok(size)
|
Ok(size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1093,11 +1073,7 @@ mod tests {
|
||||||
.into_response();
|
.into_response();
|
||||||
assert_eq!(second.status(), StatusCode::OK);
|
assert_eq!(second.status(), StatusCode::OK);
|
||||||
|
|
||||||
let keys = tracking_impl
|
let keys = tracking_impl.keys.lock().clone();
|
||||||
.keys
|
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
||||||
.clone();
|
|
||||||
assert_eq!(keys.len(), 2);
|
assert_eq!(keys.len(), 2);
|
||||||
assert_ne!(keys[0], keys[1]);
|
assert_ne!(keys[0], keys[1]);
|
||||||
assert!(keys[0].starts_with("webhook_msg_"));
|
assert!(keys[0].starts_with("webhook_msg_"));
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
|
||||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
@ -559,11 +559,12 @@ exit 1
|
||||||
"local_note",
|
"local_note",
|
||||||
"Local sqlite auth fallback note",
|
"Local sqlite auth fallback note",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let entries = memory.recall("auth", 5).await.unwrap();
|
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{Duration, Local};
|
use chrono::{Duration, Local};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{params, Connection};
|
use rusqlite::{params, Connection};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use parking_lot::Mutex;
|
|
||||||
|
|
||||||
/// Response cache backed by a dedicated SQLite database.
|
/// Response cache backed by a dedicated SQLite database.
|
||||||
///
|
///
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
use super::vector;
|
use super::vector;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
|
use parking_lot::Mutex;
|
||||||
use rusqlite::{params, Connection};
|
use rusqlite::{params, Connection};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
|
@ -186,10 +186,7 @@ impl SqliteMemory {
|
||||||
|
|
||||||
// Check cache
|
// Check cache
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
||||||
|
|
@ -211,10 +208,7 @@ impl SqliteMemory {
|
||||||
|
|
||||||
// Store in cache + LRU eviction
|
// Store in cache + LRU eviction
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
||||||
|
|
@ -316,10 +310,7 @@ impl SqliteMemory {
|
||||||
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
||||||
// Step 1: Rebuild FTS5
|
// Step 1: Rebuild FTS5
|
||||||
{
|
{
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
||||||
}
|
}
|
||||||
|
|
@ -330,10 +321,7 @@ impl SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
let entries: Vec<(String, String)> = {
|
let entries: Vec<(String, String)> = {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
||||||
|
|
@ -347,10 +335,7 @@ impl SqliteMemory {
|
||||||
for (id, content) in &entries {
|
for (id, content) in &entries {
|
||||||
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
||||||
let bytes = vector::vec_to_bytes(&emb);
|
let bytes = vector::vec_to_bytes(&emb);
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
||||||
params![bytes, id],
|
params![bytes, id],
|
||||||
|
|
@ -382,10 +367,7 @@ impl Memory for SqliteMemory {
|
||||||
.await?
|
.await?
|
||||||
.map(|emb| vector::vec_to_bytes(&emb));
|
.map(|emb| vector::vec_to_bytes(&emb));
|
||||||
|
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let now = Local::now().to_rfc3339();
|
let now = Local::now().to_rfc3339();
|
||||||
let cat = Self::category_to_str(&category);
|
let cat = Self::category_to_str(&category);
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
@ -418,10 +400,7 @@ impl Memory for SqliteMemory {
|
||||||
// Compute query embedding (async, before lock)
|
// Compute query embedding (async, before lock)
|
||||||
let query_embedding = self.get_or_compute_embedding(query).await?;
|
let query_embedding = self.get_or_compute_embedding(query).await?;
|
||||||
|
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
// FTS5 BM25 keyword search
|
// FTS5 BM25 keyword search
|
||||||
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
||||||
|
|
@ -540,10 +519,7 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||||
|
|
@ -572,10 +548,7 @@ impl Memory for SqliteMemory {
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
|
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
|
@ -628,29 +601,20 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||||
Ok(affected > 0)
|
Ok(affected > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count(&self) -> anyhow::Result<usize> {
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
let conn = self
|
let conn = self.conn.lock();
|
||||||
.conn
|
|
||||||
.lock()
|
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
|
||||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||||
Ok(count as usize)
|
Ok(count as usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
async fn health_check(&self) -> bool {
|
||||||
self.conn
|
self.conn.lock().execute_batch("SELECT 1").is_ok()
|
||||||
.lock()
|
|
||||||
.map(|c| c.execute_batch("SELECT 1").is_ok())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -688,8 +688,8 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
options: StreamOptions,
|
options: StreamOptions,
|
||||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
let api_key = match self.api_key.as_ref() {
|
let credential = match self.credential.as_ref() {
|
||||||
Some(key) => key.clone(),
|
Some(value) => value.clone(),
|
||||||
None => {
|
None => {
|
||||||
let provider_name = self.name.clone();
|
let provider_name = self.name.clone();
|
||||||
return stream::once(async move {
|
return stream::once(async move {
|
||||||
|
|
@ -735,10 +735,10 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
// Apply auth header
|
// Apply auth header
|
||||||
req_builder = match &auth_header {
|
req_builder = match &auth_header {
|
||||||
AuthStyle::Bearer => {
|
AuthStyle::Bearer => {
|
||||||
req_builder.header("Authorization", format!("Bearer {}", api_key))
|
req_builder.header("Authorization", format!("Bearer {}", credential))
|
||||||
}
|
}
|
||||||
AuthStyle::XApiKey => req_builder.header("x-api-key", &api_key),
|
AuthStyle::XApiKey => req_builder.header("x-api-key", &credential),
|
||||||
AuthStyle::Custom(header) => req_builder.header(header, &api_key),
|
AuthStyle::Custom(header) => req_builder.header(header, &credential),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set accept header for streaming
|
// Set accept header for streaming
|
||||||
|
|
|
||||||
|
|
@ -767,7 +767,7 @@ mod tests {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result, "ok from sonnet");
|
assert_eq!(result, "ok from sonnet");
|
||||||
|
|
||||||
let seen = mock.models_seen.lock().unwrap();
|
let seen = mock.models_seen.lock();
|
||||||
assert_eq!(seen.len(), 2);
|
assert_eq!(seen.len(), 2);
|
||||||
assert_eq!(seen[0], "claude-opus");
|
assert_eq!(seen[0], "claude-opus");
|
||||||
assert_eq!(seen[1], "claude-sonnet");
|
assert_eq!(seen[1], "claude-sonnet");
|
||||||
|
|
@ -802,7 +802,7 @@ mod tests {
|
||||||
.expect_err("all models should fail");
|
.expect_err("all models should fail");
|
||||||
assert!(err.to_string().contains("All providers/models failed"));
|
assert!(err.to_string().contains("All providers/models failed"));
|
||||||
|
|
||||||
let seen = mock.models_seen.lock().unwrap();
|
let seen = mock.models_seen.lock();
|
||||||
assert_eq!(seen.len(), 3);
|
assert_eq!(seen.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,13 @@ pub struct ChatRequest<'a> {
|
||||||
pub tools: Option<&'a [ToolSpec]>,
|
pub tools: Option<&'a [ToolSpec]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Declares optional provider features.
|
||||||
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||||
|
pub struct ProviderCapabilities {
|
||||||
|
/// Provider can perform native tool calling without prompt-level emulation.
|
||||||
|
pub native_tool_calling: bool,
|
||||||
|
}
|
||||||
|
|
||||||
/// A tool result to feed back to the LLM.
|
/// A tool result to feed back to the LLM.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResultMessage {
|
pub struct ToolResultMessage {
|
||||||
|
|
@ -319,11 +326,11 @@ pub trait Provider: Send + Sync {
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
_options: StreamOptions,
|
_options: StreamOptions,
|
||||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
let system = messages
|
let _system = messages
|
||||||
.iter()
|
.iter()
|
||||||
.find(|m| m.role == "system")
|
.find(|m| m.role == "system")
|
||||||
.map(|m| m.content.clone());
|
.map(|m| m.content.clone());
|
||||||
let last_user = messages
|
let _last_user = messages
|
||||||
.iter()
|
.iter()
|
||||||
.rfind(|m| m.role == "user")
|
.rfind(|m| m.role == "user")
|
||||||
.map(|m| m.content.clone())
|
.map(|m| m.content.clone())
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,11 @@
|
||||||
use crate::config::AuditConfig;
|
use crate::config::AuditConfig;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
|
use parking_lot::Mutex;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use parking_lot::Mutex;
|
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Audit event types
|
/// Audit event types
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue