Fixes #221 - SQLite Memory Override bug. This PR resolves memory overwrite behavior in autosave paths by replacing fixed memory keys with unique keys, and improves short-horizon recall quality in channel runtime. **Root Cause** SQLite memory uses a unique constraint on `memories.key` and writes with `ON CONFLICT(key) DO UPDATE`. Several autosave paths reused fixed keys (or sender-stable keys), so newer messages overwrote earlier conversation entries. **Changes** - Channel runtime: autosave key changed from `channel_sender` to `channel_sender_messageId` - Added memory-context injection before provider calls (aligned with agent loop behavior) - Agent loop: autosave keys changed from fixed `user_msg`/`assistant_resp` to UUID-suffixed keys - Gateway: Webhook/WhatsApp autosave keys changed to UUID-suffixed keys All CI checks passing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
475 lines
16 KiB
Rust
475 lines
16 KiB
Rust
use super::traits::ChatMessage;
|
|
use super::Provider;
|
|
use async_trait::async_trait;
|
|
use std::time::Duration;
|
|
|
|
/// Check if an error is non-retryable (client errors that won't resolve with retries).
|
|
fn is_non_retryable(err: &anyhow::Error) -> bool {
|
|
// Check for reqwest status errors (returned by .error_for_status())
|
|
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
|
|
if let Some(status) = reqwest_err.status() {
|
|
let code = status.as_u16();
|
|
// 4xx client errors are non-retryable, except:
|
|
// - 429 Too Many Requests (rate limiting, transient)
|
|
// - 408 Request Timeout (transient)
|
|
return status.is_client_error() && code != 429 && code != 408;
|
|
}
|
|
}
|
|
// String fallback: scan for any 4xx status code in error message
|
|
let msg = err.to_string();
|
|
for word in msg.split(|c: char| !c.is_ascii_digit()) {
|
|
if let Ok(code) = word.parse::<u16>() {
|
|
if (400..500).contains(&code) {
|
|
return code != 429 && code != 408;
|
|
}
|
|
}
|
|
}
|
|
false
|
|
}
|
|
|
|
/// Provider wrapper with retry + fallback behavior.
|
|
pub struct ReliableProvider {
|
|
providers: Vec<(String, Box<dyn Provider>)>,
|
|
max_retries: u32,
|
|
base_backoff_ms: u64,
|
|
}
|
|
|
|
impl ReliableProvider {
|
|
pub fn new(
|
|
providers: Vec<(String, Box<dyn Provider>)>,
|
|
max_retries: u32,
|
|
base_backoff_ms: u64,
|
|
) -> Self {
|
|
Self {
|
|
providers,
|
|
max_retries,
|
|
base_backoff_ms: base_backoff_ms.max(50),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Provider for ReliableProvider {
|
|
async fn warmup(&self) -> anyhow::Result<()> {
|
|
for (name, provider) in &self.providers {
|
|
tracing::info!(provider = name, "Warming up provider connection pool");
|
|
if let Err(e) = provider.warmup().await {
|
|
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn chat_with_system(
|
|
&self,
|
|
system_prompt: Option<&str>,
|
|
message: &str,
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<String> {
|
|
let mut failures = Vec::new();
|
|
|
|
for (provider_name, provider) in &self.providers {
|
|
let mut backoff_ms = self.base_backoff_ms;
|
|
|
|
for attempt in 0..=self.max_retries {
|
|
match provider
|
|
.chat_with_system(system_prompt, message, model, temperature)
|
|
.await
|
|
{
|
|
Ok(resp) => {
|
|
if attempt > 0 {
|
|
tracing::info!(
|
|
provider = provider_name,
|
|
attempt,
|
|
"Provider recovered after retries"
|
|
);
|
|
}
|
|
return Ok(resp);
|
|
}
|
|
Err(e) => {
|
|
let non_retryable = is_non_retryable(&e);
|
|
failures.push(format!(
|
|
"{provider_name} attempt {}/{}: {e}",
|
|
attempt + 1,
|
|
self.max_retries + 1
|
|
));
|
|
|
|
if non_retryable {
|
|
tracing::warn!(
|
|
provider = provider_name,
|
|
"Non-retryable error, switching provider"
|
|
);
|
|
break;
|
|
}
|
|
|
|
if attempt < self.max_retries {
|
|
tracing::warn!(
|
|
provider = provider_name,
|
|
attempt = attempt + 1,
|
|
max_retries = self.max_retries,
|
|
"Provider call failed, retrying"
|
|
);
|
|
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
|
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::warn!(provider = provider_name, "Switching to fallback provider");
|
|
}
|
|
|
|
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
|
}
|
|
|
|
async fn chat_with_history(
|
|
&self,
|
|
messages: &[ChatMessage],
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<String> {
|
|
let mut failures = Vec::new();
|
|
|
|
for (provider_name, provider) in &self.providers {
|
|
let mut backoff_ms = self.base_backoff_ms;
|
|
|
|
for attempt in 0..=self.max_retries {
|
|
match provider
|
|
.chat_with_history(messages, model, temperature)
|
|
.await
|
|
{
|
|
Ok(resp) => {
|
|
if attempt > 0 {
|
|
tracing::info!(
|
|
provider = provider_name,
|
|
attempt,
|
|
"Provider recovered after retries"
|
|
);
|
|
}
|
|
return Ok(resp);
|
|
}
|
|
Err(e) => {
|
|
let non_retryable = is_non_retryable(&e);
|
|
failures.push(format!(
|
|
"{provider_name} attempt {}/{}: {e}",
|
|
attempt + 1,
|
|
self.max_retries + 1
|
|
));
|
|
|
|
if non_retryable {
|
|
tracing::warn!(
|
|
provider = provider_name,
|
|
"Non-retryable error, switching provider"
|
|
);
|
|
break;
|
|
}
|
|
|
|
if attempt < self.max_retries {
|
|
tracing::warn!(
|
|
provider = provider_name,
|
|
attempt = attempt + 1,
|
|
max_retries = self.max_retries,
|
|
"Provider call failed, retrying"
|
|
);
|
|
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
|
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::warn!(provider = provider_name, "Switching to fallback provider");
|
|
}
|
|
|
|
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Scripted provider: the first `fail_until_attempt` calls (counted across
    /// both chat methods) fail with `error`; every later call returns `response`.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    }

    impl MockProvider {
        /// Records one call and returns the scripted outcome for it.
        fn next_outcome(&self) -> anyhow::Result<String> {
            let call_number = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
            if call_number > self.fail_until_attempt {
                Ok(self.response.to_string())
            } else {
                Err(anyhow::anyhow!(self.error))
            }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.next_outcome()
        }

        async fn chat_with_history(
            &self,
            _messages: &[ChatMessage],
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.next_outcome()
        }
    }

    /// Builds a `(name, provider)` entry backed by a `MockProvider` that
    /// increments the shared `calls` counter.
    fn mock(
        name: &str,
        calls: &Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    ) -> (String, Box<dyn Provider>) {
        let provider: Box<dyn Provider> = Box::new(MockProvider {
            calls: Arc::clone(calls),
            fail_until_attempt,
            response,
            error,
        });
        (name.to_string(), provider)
    }

    /// A fresh call counter starting at zero.
    fn counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    #[tokio::test]
    async fn succeeds_without_retry() {
        let calls = counter();
        let provider = ReliableProvider::new(
            vec![mock("primary", &calls, 0, "ok", "boom")],
            2,
            1,
        );

        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_then_recovers() {
        let calls = counter();
        let provider = ReliableProvider::new(
            vec![mock("primary", &calls, 1, "recovered", "temporary")],
            2,
            1,
        );

        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn falls_back_after_retries_exhausted() {
        let primary_calls = counter();
        let fallback_calls = counter();
        let provider = ReliableProvider::new(
            vec![
                mock("primary", &primary_calls, usize::MAX, "never", "primary down"),
                mock("fallback", &fallback_calls, 0, "from fallback", "fallback down"),
            ],
            1,
            1,
        );

        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "from fallback");
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn returns_aggregated_error_when_all_providers_fail() {
        let provider = ReliableProvider::new(
            vec![
                mock("p1", &counter(), usize::MAX, "never", "p1 error"),
                mock("p2", &counter(), usize::MAX, "never", "p2 error"),
            ],
            0,
            1,
        );

        let err = provider
            .chat("hello", "test", 0.0)
            .await
            .expect_err("all providers should fail");
        let msg = err.to_string();
        assert!(msg.contains("All providers failed"));
        assert!(msg.contains("p1 attempt 1/1"));
        assert!(msg.contains("p2 attempt 1/1"));
    }

    #[test]
    fn non_retryable_detects_common_patterns() {
        // Non-retryable 4xx errors
        assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request")));
        assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized")));
        assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden")));
        assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found")));
        assert!(is_non_retryable(&anyhow::anyhow!(
            "API error with 400 Bad Request"
        )));
        // Retryable: 429 Too Many Requests
        assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
        // Retryable: 408 Request Timeout
        assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
        // Retryable: 5xx server errors
        assert!(!is_non_retryable(&anyhow::anyhow!(
            "500 Internal Server Error"
        )));
        assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway")));
        // Retryable: transient errors
        assert!(!is_non_retryable(&anyhow::anyhow!("timeout")));
        assert!(!is_non_retryable(&anyhow::anyhow!("connection reset")));
    }

    #[tokio::test]
    async fn skips_retries_on_non_retryable_error() {
        let primary_calls = counter();
        let fallback_calls = counter();
        let provider = ReliableProvider::new(
            vec![
                mock("primary", &primary_calls, usize::MAX, "never", "401 Unauthorized"),
                mock("fallback", &fallback_calls, 0, "from fallback", "fallback err"),
            ],
            3, // 3 retries allowed, but should skip them
            1,
        );

        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "from fallback");
        // Primary should have been called only once (no retries)
        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn chat_with_history_retries_then_recovers() {
        let calls = counter();
        let provider = ReliableProvider::new(
            vec![mock("primary", &calls, 1, "history ok", "temporary")],
            2,
            1,
        );

        let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
        let result = provider
            .chat_with_history(&messages, "test", 0.0)
            .await
            .unwrap();
        assert_eq!(result, "history ok");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chat_with_history_falls_back() {
        let primary_calls = counter();
        let fallback_calls = counter();
        let provider = ReliableProvider::new(
            vec![
                mock("primary", &primary_calls, usize::MAX, "never", "primary down"),
                mock("fallback", &fallback_calls, 0, "fallback ok", "fallback err"),
            ],
            1,
            1,
        );

        let messages = vec![ChatMessage::user("hello")];
        let result = provider
            .chat_with_history(&messages, "test", 0.0)
            .await
            .unwrap();
        assert_eq!(result, "fallback ok");
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }
}
|