diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 5c20c52..791f13d 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -2,6 +2,30 @@ use super::Provider; use async_trait::async_trait; use std::time::Duration; +/// Check if an error is non-retryable (client errors that won't resolve with retries). +fn is_non_retryable(err: &anyhow::Error) -> bool { + // Check for reqwest status errors (returned by .error_for_status()) + if let Some(reqwest_err) = err.downcast_ref::() { + if let Some(status) = reqwest_err.status() { + let code = status.as_u16(); + // 4xx client errors are non-retryable, except: + // - 429 Too Many Requests (rate limiting, transient) + // - 408 Request Timeout (transient) + return status.is_client_error() && code != 429 && code != 408; + } + } + // String fallback: scan for any 4xx status code in error message + let msg = err.to_string(); + for word in msg.split(|c: char| !c.is_ascii_digit()) { + if let Ok(code) = word.parse::() { + if (400..500).contains(&code) { + return code != 429 && code != 408; + } + } + } + false +} + /// Provider wrapper with retry + fallback behavior. pub struct ReliableProvider { providers: Vec<(String, Box)>, @@ -63,12 +87,21 @@ impl Provider for ReliableProvider { return Ok(resp); } Err(e) => { + let non_retryable = is_non_retryable(&e); failures.push(format!( "{provider_name} attempt {}/{}: {e}", attempt + 1, self.max_retries + 1 )); + if non_retryable { + tracing::warn!( + provider = provider_name, + "Non-retryable error, switching provider" + ); + break; + } + if attempt < self.max_retries { tracing::warn!( provider = provider_name, @@ -236,4 +269,67 @@ mod tests { assert!(msg.contains("p1 attempt 1/1")); assert!(msg.contains("p2 attempt 1/1")); } + + #[test] + fn non_retryable_detects_common_patterns() { + // Non-retryable 4xx errors + assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request"))); + assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized"))); + assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden"))); + assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found"))); + assert!(is_non_retryable(&anyhow::anyhow!( + "API error with 400 Bad Request" + ))); + // Retryable: 429 Too Many Requests + assert!(!is_non_retryable(&anyhow::anyhow!( + "429 Too Many Requests" + ))); + // Retryable: 408 Request Timeout + assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout"))); + // Retryable: 5xx server errors + assert!(!is_non_retryable(&anyhow::anyhow!( + "500 Internal Server Error" + ))); + assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway"))); + // Retryable: transient errors + assert!(!is_non_retryable(&anyhow::anyhow!("timeout"))); + assert!(!is_non_retryable(&anyhow::anyhow!("connection reset"))); + } + + #[tokio::test] + async fn skips_retries_on_non_retryable_error() { + let primary_calls = Arc::new(AtomicUsize::new(0)); + let fallback_calls = Arc::new(AtomicUsize::new(0)); + + let provider = ReliableProvider::new( + vec![ + ( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&primary_calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "401 Unauthorized", + }), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "from fallback", + error: "fallback err", + }), + ), + ], + 3, // 3 retries allowed, but should skip them + 1, + ); + + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); + // Primary should have been called only once (no retries) + assert_eq!(primary_calls.load(Ordering::SeqCst), 1); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } }