zeroclaw/src/providers/reliable.rs
Argenis 8694c2e2d2
fix(providers): skip retries on non-retryable HTTP errors (4xx)
Skip retries on non-retryable HTTP client errors (4xx) to avoid wasting time on requests that will never succeed.

- Added is_non_retryable() function to detect non-retryable errors
- All 4xx client errors (e.g. 400, 401, 403, 404) are now non-retryable
- Exceptions: 429 (rate limiting) and 408 (timeout) remain retryable
- 5xx server errors remain retryable
- Fallback logic now skips retries for non-retryable errors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 10:11:32 -05:00

335 lines
11 KiB
Rust

use super::Provider;
use async_trait::async_trait;
use std::time::Duration;
/// Check if an error is non-retryable (client errors that won't resolve with retries).
fn is_non_retryable(err: &anyhow::Error) -> bool {
// Check for reqwest status errors (returned by .error_for_status())
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_err.status() {
let code = status.as_u16();
// 4xx client errors are non-retryable, except:
// - 429 Too Many Requests (rate limiting, transient)
// - 408 Request Timeout (transient)
return status.is_client_error() && code != 429 && code != 408;
}
}
// String fallback: scan for any 4xx status code in error message
let msg = err.to_string();
for word in msg.split(|c: char| !c.is_ascii_digit()) {
if let Ok(code) = word.parse::<u16>() {
if (400..500).contains(&code) {
return code != 429 && code != 408;
}
}
}
false
}
/// Provider wrapper with retry + fallback behavior.
///
/// Holds an ordered list of named providers. Requests go to them in order, and
/// each provider may be retried before the next one is used as a fallback.
pub struct ReliableProvider {
    /// Named providers, in the order they are tried.
    providers: Vec<(String, Box<dyn Provider>)>,
    /// Maximum number of retries per provider after its first attempt.
    max_retries: u32,
    /// Initial delay between retries, in milliseconds (at least 50).
    base_backoff_ms: u64,
}

impl ReliableProvider {
    /// Creates a wrapper around `providers`.
    ///
    /// Each provider gets up to `max_retries + 1` attempts. A
    /// `base_backoff_ms` below 50 is raised to 50.
    pub fn new(
        providers: Vec<(String, Box<dyn Provider>)>,
        max_retries: u32,
        base_backoff_ms: u64,
    ) -> Self {
        // Enforce a 50 ms floor on the initial backoff.
        let base_backoff_ms = std::cmp::max(base_backoff_ms, 50);
        Self {
            providers,
            max_retries,
            base_backoff_ms,
        }
    }
}
#[async_trait]
impl Provider for ReliableProvider {
    /// Warms up every wrapped provider, in order.
    ///
    /// Warmup is best-effort: individual failures are logged and ignored, so
    /// this always returns `Ok(())`.
    async fn warmup(&self) -> anyhow::Result<()> {
        for (name, provider) in &self.providers {
            tracing::info!(provider = name, "Warming up provider connection pool");
            if let Err(e) = provider.warmup().await {
                tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
            }
        }
        Ok(())
    }

    /// Sends the prompt to each wrapped provider in order until one succeeds.
    ///
    /// Each provider gets up to `max_retries + 1` attempts. Between attempts the
    /// call sleeps with exponential backoff: it starts at `base_backoff_ms`,
    /// doubles each time and never exceeds 10 s. An error classified as
    /// non-retryable by `is_non_retryable` ends that provider's attempts
    /// immediately and moves on to the next provider.
    ///
    /// # Errors
    ///
    /// Fails when every provider has failed, or when no providers are
    /// configured. The error message lists each failed attempt as
    /// `"<provider> attempt <n>/<total>: <error>"`.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        // Upper bound for a single backoff sleep, in milliseconds.
        const MAX_BACKOFF_MS: u64 = 10_000;

        // Computed in u64 so that `max_retries == u32::MAX` cannot overflow.
        let total_attempts = u64::from(self.max_retries) + 1;
        let mut failures = Vec::new();

        for (index, (provider_name, provider)) in self.providers.iter().enumerate() {
            // Clamp up front as well, so an oversized base never yields a first
            // sleep longer than the later (capped) ones.
            let mut backoff_ms = self.base_backoff_ms.min(MAX_BACKOFF_MS);

            for attempt in 0..=self.max_retries {
                match provider
                    .chat_with_system(system_prompt, message, model, temperature)
                    .await
                {
                    Ok(resp) => {
                        if attempt > 0 {
                            tracing::info!(
                                provider = provider_name,
                                attempt,
                                "Provider recovered after retries"
                            );
                        }
                        return Ok(resp);
                    }
                    Err(e) => {
                        let non_retryable = is_non_retryable(&e);
                        failures.push(format!(
                            "{provider_name} attempt {}/{total_attempts}: {e}",
                            u64::from(attempt) + 1,
                        ));
                        if non_retryable {
                            // The fallback switch itself is logged below, once.
                            tracing::warn!(
                                provider = provider_name,
                                error = %e,
                                "Non-retryable error, skipping remaining retries"
                            );
                            break;
                        }
                        if attempt < self.max_retries {
                            tracing::warn!(
                                provider = provider_name,
                                attempt = attempt + 1,
                                max_retries = self.max_retries,
                                error = %e,
                                "Provider call failed, retrying"
                            );
                            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                            backoff_ms = backoff_ms.saturating_mul(2).min(MAX_BACKOFF_MS);
                        }
                    }
                }
            }

            // Only announce a fallback when there actually is one left to try.
            if index + 1 < self.providers.len() {
                tracing::warn!(provider = provider_name, "Switching to fallback provider");
            }
        }

        anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Scripted provider.
    ///
    /// The first `fail_until_attempt` calls fail with `error`; every later call
    /// returns `response`. Each call increments `calls`.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            // 1-based number of this call.
            let call_number = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
            if call_number > self.fail_until_attempt {
                Ok(self.response.to_string())
            } else {
                Err(anyhow::anyhow!(self.error))
            }
        }
    }

    /// Builds a named, boxed `MockProvider` entry that counts calls into `calls`.
    fn mock(
        name: &str,
        calls: &Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    ) -> (String, Box<dyn Provider>) {
        let provider: Box<dyn Provider> = Box::new(MockProvider {
            calls: Arc::clone(calls),
            fail_until_attempt,
            response,
            error,
        });
        (name.to_owned(), provider)
    }

    /// A fresh call counter starting at zero.
    fn counter() -> Arc<AtomicUsize> {
        Arc::new(AtomicUsize::new(0))
    }

    #[tokio::test]
    async fn succeeds_without_retry() {
        let calls = counter();
        let reliable = ReliableProvider::new(vec![mock("primary", &calls, 0, "ok", "boom")], 2, 1);

        let reply = reliable.chat("hello", "test", 0.0).await.unwrap();

        assert_eq!(reply, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_then_recovers() {
        let calls = counter();
        let reliable = ReliableProvider::new(
            vec![mock("primary", &calls, 1, "recovered", "temporary")],
            2,
            1,
        );

        let reply = reliable.chat("hello", "test", 0.0).await.unwrap();

        assert_eq!(reply, "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn falls_back_after_retries_exhausted() {
        let (primary_calls, fallback_calls) = (counter(), counter());
        let reliable = ReliableProvider::new(
            vec![
                mock("primary", &primary_calls, usize::MAX, "never", "primary down"),
                mock("fallback", &fallback_calls, 0, "from fallback", "fallback down"),
            ],
            1,
            1,
        );

        let reply = reliable.chat("hello", "test", 0.0).await.unwrap();

        assert_eq!(reply, "from fallback");
        // One initial attempt plus one retry on the primary.
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn returns_aggregated_error_when_all_providers_fail() {
        let reliable = ReliableProvider::new(
            vec![
                mock("p1", &counter(), usize::MAX, "never", "p1 error"),
                mock("p2", &counter(), usize::MAX, "never", "p2 error"),
            ],
            0,
            1,
        );

        let err = reliable
            .chat("hello", "test", 0.0)
            .await
            .expect_err("all providers should fail");
        let msg = err.to_string();

        assert!(msg.contains("All providers failed"));
        assert!(msg.contains("p1 attempt 1/1"));
        assert!(msg.contains("p2 attempt 1/1"));
    }

    #[test]
    fn non_retryable_detects_common_patterns() {
        // 4xx client errors: retrying will not help.
        let non_retryable = [
            "400 Bad Request",
            "401 Unauthorized",
            "403 Forbidden",
            "404 Not Found",
            "API error with 400 Bad Request",
        ];
        for msg in non_retryable {
            assert!(
                is_non_retryable(&anyhow::anyhow!(msg)),
                "expected non-retryable: {msg}"
            );
        }

        // Rate limiting, request timeouts, 5xx and transport errors are transient.
        let retryable = [
            "429 Too Many Requests",
            "408 Request Timeout",
            "500 Internal Server Error",
            "502 Bad Gateway",
            "timeout",
            "connection reset",
        ];
        for msg in retryable {
            assert!(
                !is_non_retryable(&anyhow::anyhow!(msg)),
                "expected retryable: {msg}"
            );
        }
    }

    #[tokio::test]
    async fn skips_retries_on_non_retryable_error() {
        let (primary_calls, fallback_calls) = (counter(), counter());
        let reliable = ReliableProvider::new(
            vec![
                mock("primary", &primary_calls, usize::MAX, "never", "401 Unauthorized"),
                mock("fallback", &fallback_calls, 0, "from fallback", "fallback err"),
            ],
            3, // 3 retries allowed, but the 401 should skip them
            1,
        );

        let reply = reliable.chat("hello", "test", 0.0).await.unwrap();

        assert_eq!(reply, "from fallback");
        // Primary is called exactly once: no retries after a non-retryable error.
        assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }
}