fix(providers): skip retries on non-retryable HTTP errors (4xx)
Skip retries on non-retryable HTTP client errors (4xx) to avoid wasting time on requests that will never succeed.

- Added is_non_retryable() function to detect non-retryable errors
- 4xx client errors (e.g. 400, 401, 403, 404) are now non-retryable
- Exceptions: 429 (rate limiting) and 408 (timeout) remain retryable
- 5xx server errors remain retryable
- Fallback logic now skips retries for non-retryable errors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
be135e07cf
commit
8694c2e2d2
1 changed files with 96 additions and 0 deletions
|
|
@ -2,6 +2,30 @@ use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Check if an error is non-retryable (client errors that won't resolve with retries).
|
||||||
|
fn is_non_retryable(err: &anyhow::Error) -> bool {
|
||||||
|
// Check for reqwest status errors (returned by .error_for_status())
|
||||||
|
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
|
||||||
|
if let Some(status) = reqwest_err.status() {
|
||||||
|
let code = status.as_u16();
|
||||||
|
// 4xx client errors are non-retryable, except:
|
||||||
|
// - 429 Too Many Requests (rate limiting, transient)
|
||||||
|
// - 408 Request Timeout (transient)
|
||||||
|
return status.is_client_error() && code != 429 && code != 408;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// String fallback: scan for any 4xx status code in error message
|
||||||
|
let msg = err.to_string();
|
||||||
|
for word in msg.split(|c: char| !c.is_ascii_digit()) {
|
||||||
|
if let Ok(code) = word.parse::<u16>() {
|
||||||
|
if (400..500).contains(&code) {
|
||||||
|
return code != 429 && code != 408;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
/// Provider wrapper with retry + fallback behavior.
|
/// Provider wrapper with retry + fallback behavior.
|
||||||
pub struct ReliableProvider {
|
pub struct ReliableProvider {
|
||||||
providers: Vec<(String, Box<dyn Provider>)>,
|
providers: Vec<(String, Box<dyn Provider>)>,
|
||||||
|
|
@ -63,12 +87,21 @@ impl Provider for ReliableProvider {
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
let non_retryable = is_non_retryable(&e);
|
||||||
failures.push(format!(
|
failures.push(format!(
|
||||||
"{provider_name} attempt {}/{}: {e}",
|
"{provider_name} attempt {}/{}: {e}",
|
||||||
attempt + 1,
|
attempt + 1,
|
||||||
self.max_retries + 1
|
self.max_retries + 1
|
||||||
));
|
));
|
||||||
|
|
||||||
|
if non_retryable {
|
||||||
|
tracing::warn!(
|
||||||
|
provider = provider_name,
|
||||||
|
"Non-retryable error, switching provider"
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
if attempt < self.max_retries {
|
if attempt < self.max_retries {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
provider = provider_name,
|
provider = provider_name,
|
||||||
|
|
@ -236,4 +269,67 @@ mod tests {
|
||||||
assert!(msg.contains("p1 attempt 1/1"));
|
assert!(msg.contains("p1 attempt 1/1"));
|
||||||
assert!(msg.contains("p2 attempt 1/1"));
|
assert!(msg.contains("p2 attempt 1/1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn non_retryable_detects_common_patterns() {
    // Messages carrying a 4xx status that must stop further retries.
    let non_retryable = [
        "400 Bad Request",
        "401 Unauthorized",
        "403 Forbidden",
        "404 Not Found",
        "API error with 400 Bad Request",
    ];
    for msg in non_retryable {
        assert!(is_non_retryable(&anyhow::anyhow!("{msg}")));
    }

    // Messages that must stay retryable: 429 and 408 (the two 4xx
    // exceptions), 5xx server errors, and errors with no status at all.
    let retryable = [
        "429 Too Many Requests",
        "408 Request Timeout",
        "500 Internal Server Error",
        "502 Bad Gateway",
        "timeout",
        "connection reset",
    ];
    for msg in retryable {
        assert!(!is_non_retryable(&anyhow::anyhow!("{msg}")));
    }
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn skips_retries_on_non_retryable_error() {
|
||||||
|
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![
|
||||||
|
(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&primary_calls),
|
||||||
|
fail_until_attempt: usize::MAX,
|
||||||
|
response: "never",
|
||||||
|
error: "401 Unauthorized",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"fallback".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&fallback_calls),
|
||||||
|
fail_until_attempt: 0,
|
||||||
|
response: "from fallback",
|
||||||
|
error: "fallback err",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
3, // 3 retries allowed, but should skip them
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||||
|
assert_eq!(result, "from fallback");
|
||||||
|
// Primary should have been called only once (no retries)
|
||||||
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||||
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue