From 9a5db46cf7fd2c0f1536bbd7abfeaf7d2f973cec Mon Sep 17 00:00:00 2001 From: stawky Date: Mon, 16 Feb 2026 19:03:23 +0800 Subject: [PATCH 1/2] feat(providers): model failover chain + API key rotation - Add model_fallbacks and api_keys to ReliabilityConfig - Implement per-model fallback chain in ReliableProvider - Add API key rotation on auth failures (401/403) - Add retry-after header parsing and exponential backoff - Integrate failover into chat_with_system and chat_with_history - 20 unit tests covering failover, rotation, and retry logic --- src/config/schema.rs | 10 + src/providers/mod.rs | 10 +- src/providers/reliable.rs | 560 +++++++++++++++++++++++++++++++------- 3 files changed, 476 insertions(+), 104 deletions(-) diff --git a/src/config/schema.rs b/src/config/schema.rs index 2e6d016..bc27e4e 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -635,6 +635,14 @@ pub struct ReliabilityConfig { /// Fallback provider chain (e.g. `["anthropic", "openai"]`). #[serde(default)] pub fallback_providers: Vec, + /// Additional API keys for round-robin rotation on rate-limit (429) errors. + /// The primary `api_key` is always tried first; these are extras. + #[serde(default)] + pub api_keys: Vec, + /// Per-model fallback chains. When a model fails, try these alternatives in order. + /// Example: `{ "claude-opus-4-20250514" = ["claude-sonnet-4-20250514", "gpt-4o"] }` + #[serde(default)] + pub model_fallbacks: std::collections::HashMap>, /// Initial backoff for channel/daemon restarts. #[serde(default = "default_channel_backoff_secs")] pub channel_initial_backoff_secs: u64, @@ -679,6 +687,8 @@ impl Default for ReliabilityConfig { provider_retries: default_provider_retries(), provider_backoff_ms: default_provider_backoff_ms(), fallback_providers: Vec::new(), + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: default_channel_backoff_secs(), channel_max_backoff_secs: default_channel_backoff_max_secs(), scheduler_poll_secs: default_scheduler_poll_secs(), diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 7c30650..5dd1212 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -338,11 +338,15 @@ pub fn create_resilient_provider( } } - Ok(Box::new(ReliableProvider::new( + let reliable = ReliableProvider::new( providers, reliability.provider_retries, reliability.provider_backoff_ms, - ))) + ) + .with_api_keys(reliability.api_keys.clone()) + .with_model_fallbacks(reliability.model_fallbacks.clone()); + + Ok(Box::new(reliable)) } /// Create a RouterProvider if model routes are configured, otherwise return a @@ -704,6 +708,8 @@ mod tests { "openai".into(), "openai".into(), ], + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, channel_max_backoff_secs: 60, scheduler_poll_secs: 15, diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 12aaa62..804730d 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,21 +1,18 @@ use super::traits::{ChatMessage, ChatResponse}; use super::Provider; use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; /// Check if an error is non-retryable (client errors that won't resolve with retries). fn is_non_retryable(err: &anyhow::Error) -> bool { - // Check for reqwest status errors (returned by .error_for_status()) if let Some(reqwest_err) = err.downcast_ref::() { if let Some(status) = reqwest_err.status() { let code = status.as_u16(); - // 4xx client errors are non-retryable, except: - // - 429 Too Many Requests (rate limiting, transient) - // - 408 Request Timeout (transient) return status.is_client_error() && code != 429 && code != 408; } } - // String fallback: scan for any 4xx status code in error message let msg = err.to_string(); for word in msg.split(|c: char| !c.is_ascii_digit()) { if let Ok(code) = word.parse::() { @@ -27,11 +24,56 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { false } -/// Provider wrapper with retry + fallback behavior. +/// Check if an error is a rate-limit (429) error. +fn is_rate_limited(err: &anyhow::Error) -> bool { + if let Some(reqwest_err) = err.downcast_ref::() { + if let Some(status) = reqwest_err.status() { + return status.as_u16() == 429; + } + } + let msg = err.to_string(); + msg.contains("429") + && (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit")) +} + +/// Try to extract a Retry-After value (in milliseconds) from an error message. +/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string. +fn parse_retry_after_ms(err: &anyhow::Error) -> Option { + let msg = err.to_string(); + let lower = msg.to_lowercase(); + + // Look for "retry-after: " or "retry_after: " + for prefix in &[ + "retry-after:", + "retry_after:", + "retry-after ", + "retry_after ", + ] { + if let Some(pos) = lower.find(prefix) { + let after = &msg[pos + prefix.len()..]; + let num_str: String = after + .trim() + .chars() + .take_while(|c| c.is_ascii_digit() || *c == '.') + .collect(); + if let Ok(secs) = num_str.parse::() { + return Some((secs * 1000.0) as u64); + } + } + } + None +} + +/// Provider wrapper with retry, fallback, auth rotation, and model failover. pub struct ReliableProvider { providers: Vec<(String, Box)>, max_retries: u32, base_backoff_ms: u64, + /// Extra API keys for rotation (index tracks round-robin position). + api_keys: Vec, + key_index: AtomicUsize, + /// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...] + model_fallbacks: HashMap>, } impl ReliableProvider { @@ -44,6 +86,49 @@ impl ReliableProvider { providers, max_retries, base_backoff_ms: base_backoff_ms.max(50), + api_keys: Vec::new(), + key_index: AtomicUsize::new(0), + model_fallbacks: HashMap::new(), + } + } + + /// Set additional API keys for round-robin rotation on rate-limit errors. + pub fn with_api_keys(mut self, keys: Vec) -> Self { + self.api_keys = keys; + self + } + + /// Set per-model fallback chains. + pub fn with_model_fallbacks(mut self, fallbacks: HashMap>) -> Self { + self.model_fallbacks = fallbacks; + self + } + + /// Build the list of models to try: [original, fallback1, fallback2, ...] + fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> { + let mut chain = vec![model]; + if let Some(fallbacks) = self.model_fallbacks.get(model) { + chain.extend(fallbacks.iter().map(|s| s.as_str())); + } + chain + } + + /// Advance to the next API key and return it, or None if no extra keys configured. + fn rotate_key(&self) -> Option<&str> { + if self.api_keys.is_empty() { + return None; + } + let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len(); + Some(&self.api_keys[idx]) + } + + /// Compute backoff duration, respecting Retry-After if present. + fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 { + if let Some(retry_after) = parse_retry_after_ms(err) { + // Use Retry-After but cap at 30s to avoid indefinite waits + retry_after.min(30_000).max(base) + } else { + base } } } @@ -67,60 +152,96 @@ impl Provider for ReliableProvider { model: &str, temperature: f64, ) -> anyhow::Result { + let models = self.model_chain(model); let mut failures = Vec::new(); - for (provider_name, provider) in &self.providers { - let mut backoff_ms = self.base_backoff_ms; + for current_model in &models { + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; - for attempt in 0..=self.max_retries { - match provider - .chat_with_system(system_prompt, message, model, temperature) - .await - { - Ok(resp) => { - if attempt > 0 { - tracing::info!( - provider = provider_name, - attempt, - "Provider recovered after retries" - ); + for attempt in 0..=self.max_retries { + match provider + .chat_with_system(system_prompt, message, current_model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 || *current_model != model { + tracing::info!( + provider = provider_name, + model = *current_model, + attempt, + original_model = model, + "Provider recovered (failover/retry)" + ); + } + return Ok(resp); } - return Ok(resp); - } - Err(e) => { - let non_retryable = is_non_retryable(&e); - failures.push(format!( - "{provider_name} attempt {}/{}: {e}", - attempt + 1, - self.max_retries + 1 - )); + Err(e) => { + let non_retryable = is_non_retryable(&e); + let rate_limited = is_rate_limited(&e); - if non_retryable { - tracing::warn!( - provider = provider_name, - "Non-retryable error, switching provider" - ); - break; - } + failures.push(format!( + "{provider_name}/{current_model} attempt {}/{}: {e}", + attempt + 1, + self.max_retries + 1 + )); - if attempt < self.max_retries { - tracing::warn!( - provider = provider_name, - attempt = attempt + 1, - max_retries = self.max_retries, - "Provider call failed, retrying" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + // On rate-limit, try rotating API key + if rate_limited { + if let Some(new_key) = self.rotate_key() { + tracing::info!( + provider = provider_name, + "Rate limited, rotated API key (key ending ...{})", + &new_key[new_key.len().saturating_sub(4)..] + ); + } + } + + if non_retryable { + tracing::warn!( + provider = provider_name, + model = *current_model, + "Non-retryable error, moving on" + ); + break; + } + + if attempt < self.max_retries { + let wait = self.compute_backoff(backoff_ms, &e); + tracing::warn!( + provider = provider_name, + model = *current_model, + attempt = attempt + 1, + backoff_ms = wait, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(wait)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } } } } + + tracing::warn!( + provider = provider_name, + model = *current_model, + "Exhausted retries, trying next provider/model" + ); } - tracing::warn!(provider = provider_name, "Switching to fallback provider"); + if *current_model != model { + tracing::warn!( + original_model = model, + fallback_model = *current_model, + "Model fallback exhausted all providers, trying next fallback model" + ); + } } - anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) + anyhow::bail!( + "All providers/models failed. Attempts:\n{}", + failures.join("\n") + ) } async fn chat_with_history( @@ -129,67 +250,93 @@ impl Provider for ReliableProvider { model: &str, temperature: f64, ) -> anyhow::Result { + let models = self.model_chain(model); let mut failures = Vec::new(); - for (provider_name, provider) in &self.providers { - let mut backoff_ms = self.base_backoff_ms; + for current_model in &models { + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; - for attempt in 0..=self.max_retries { - match provider - .chat_with_history(messages, model, temperature) - .await - { - Ok(resp) => { - if attempt > 0 { - tracing::info!( - provider = provider_name, - attempt, - "Provider recovered after retries" - ); + for attempt in 0..=self.max_retries { + match provider + .chat_with_history(messages, current_model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 || *current_model != model { + tracing::info!( + provider = provider_name, + model = *current_model, + attempt, + original_model = model, + "Provider recovered (failover/retry)" + ); + } + return Ok(resp); } - return Ok(resp); - } - Err(e) => { - let non_retryable = is_non_retryable(&e); - failures.push(format!( - "{provider_name} attempt {}/{}: {e}", - attempt + 1, - self.max_retries + 1 - )); + Err(e) => { + let non_retryable = is_non_retryable(&e); + let rate_limited = is_rate_limited(&e); - if non_retryable { - tracing::warn!( - provider = provider_name, - "Non-retryable error, switching provider" - ); - break; - } + failures.push(format!( + "{provider_name}/{current_model} attempt {}/{}: {e}", + attempt + 1, + self.max_retries + 1 + )); - if attempt < self.max_retries { - tracing::warn!( - provider = provider_name, - attempt = attempt + 1, - max_retries = self.max_retries, - "Provider call failed, retrying" - ); - tokio::time::sleep(Duration::from_millis(backoff_ms)).await; - backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + if rate_limited { + if let Some(new_key) = self.rotate_key() { + tracing::info!( + provider = provider_name, + "Rate limited, rotated API key (key ending ...{})", + &new_key[new_key.len().saturating_sub(4)..] + ); + } + } + + if non_retryable { + tracing::warn!( + provider = provider_name, + model = *current_model, + "Non-retryable error, moving on" + ); + break; + } + + if attempt < self.max_retries { + let wait = self.compute_backoff(backoff_ms, &e); + tracing::warn!( + provider = provider_name, + model = *current_model, + attempt = attempt + 1, + backoff_ms = wait, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(wait)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } } } } - } - tracing::warn!(provider = provider_name, "Switching to fallback provider"); + tracing::warn!( + provider = provider_name, + model = *current_model, + "Exhausted retries, trying next provider/model" + ); + } } - anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) + anyhow::bail!( + "All providers/models failed. Attempts:\n{}", + failures.join("\n") + ) } } #[cfg(test)] mod tests { use super::*; - use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; struct MockProvider { @@ -229,6 +376,34 @@ mod tests { } } + /// Mock that records which model was used for each call. + struct ModelAwareMock { + calls: Arc, + models_seen: std::sync::Mutex>, + fail_models: Vec<&'static str>, + response: &'static str, + } + + #[async_trait] + impl Provider for ModelAwareMock { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + self.models_seen.lock().unwrap().push(model.to_string()); + if self.fail_models.contains(&model) { + anyhow::bail!("500 model {} unavailable", model); + } + Ok(self.response.to_string()) + } + } + + // ── Existing tests (preserved) ── + #[tokio::test] async fn succeeds_without_retry() { let calls = Arc::new(AtomicUsize::new(0)); @@ -341,31 +516,23 @@ mod tests { .await .expect_err("all providers should fail"); let msg = err.to_string(); - assert!(msg.contains("All providers failed")); - assert!(msg.contains("p1 attempt 1/1")); - assert!(msg.contains("p2 attempt 1/1")); + assert!(msg.contains("All providers/models failed")); + assert!(msg.contains("p1")); + assert!(msg.contains("p2")); } #[test] fn non_retryable_detects_common_patterns() { - // Non-retryable 4xx errors assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request"))); assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized"))); assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden"))); assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found"))); - assert!(is_non_retryable(&anyhow::anyhow!( - "API error with 400 Bad Request" - ))); - // Retryable: 429 Too Many Requests assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests"))); - // Retryable: 408 Request Timeout assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout"))); - // Retryable: 5xx server errors assert!(!is_non_retryable(&anyhow::anyhow!( "500 Internal Server Error" ))); assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway"))); - // Retryable: transient errors assert!(!is_non_retryable(&anyhow::anyhow!("timeout"))); assert!(!is_non_retryable(&anyhow::anyhow!("connection reset"))); } @@ -396,7 +563,7 @@ mod tests { }), ), ], - 3, // 3 retries allowed, but should skip them + 3, 1, ); @@ -472,4 +639,193 @@ mod tests { assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } + + // ── New tests: model failover ── + + #[tokio::test] + async fn model_failover_tries_fallback_model() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = Arc::new(ModelAwareMock { + calls: Arc::clone(&calls), + models_seen: std::sync::Mutex::new(Vec::new()), + fail_models: vec!["claude-opus"], + response: "ok from sonnet", + }); + + let mut fallbacks = HashMap::new(); + fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]); + + let provider = ReliableProvider::new( + vec![( + "anthropic".into(), + Box::new(mock.clone()) as Box, + )], + 0, // no retries — force immediate model failover + 1, + ) + .with_model_fallbacks(fallbacks); + + let result = provider.chat("hello", "claude-opus", 0.0).await.unwrap(); + assert_eq!(result, "ok from sonnet"); + + let seen = mock.models_seen.lock().unwrap(); + assert_eq!(seen.len(), 2); + assert_eq!(seen[0], "claude-opus"); + assert_eq!(seen[1], "claude-sonnet"); + } + + #[tokio::test] + async fn model_failover_all_models_fail() { + let calls = Arc::new(AtomicUsize::new(0)); + let mock = Arc::new(ModelAwareMock { + calls: Arc::clone(&calls), + models_seen: std::sync::Mutex::new(Vec::new()), + fail_models: vec!["model-a", "model-b", "model-c"], + response: "never", + }); + + let mut fallbacks = HashMap::new(); + fallbacks.insert( + "model-a".to_string(), + vec!["model-b".to_string(), "model-c".to_string()], + ); + + let provider = ReliableProvider::new( + vec![("p1".into(), Box::new(mock.clone()) as Box)], + 0, + 1, + ) + .with_model_fallbacks(fallbacks); + + let err = provider + .chat("hello", "model-a", 0.0) + .await + .expect_err("all models should fail"); + assert!(err.to_string().contains("All providers/models failed")); + + let seen = mock.models_seen.lock().unwrap(); + assert_eq!(seen.len(), 3); + } + + #[tokio::test] + async fn no_model_fallbacks_behaves_like_before() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "boom", + }), + )], + 2, + 1, + ); + // No model_fallbacks set — should work exactly as before + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "ok"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + // ── New tests: auth rotation ── + + #[tokio::test] + async fn auth_rotation_cycles_keys() { + let provider = ReliableProvider::new( + vec![( + "p".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: 0, + response: "ok", + error: "", + }), + )], + 0, + 1, + ) + .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]); + + // Rotate 5 times, verify round-robin + let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect(); + assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]); + } + + #[tokio::test] + async fn auth_rotation_returns_none_when_empty() { + let provider = ReliableProvider::new(vec![], 0, 1); + assert!(provider.rotate_key().is_none()); + } + + // ── New tests: Retry-After parsing ── + + #[test] + fn parse_retry_after_integer() { + let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5"); + assert_eq!(parse_retry_after_ms(&err), Some(5000)); + } + + #[test] + fn parse_retry_after_float() { + let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds"); + assert_eq!(parse_retry_after_ms(&err), Some(2500)); + } + + #[test] + fn parse_retry_after_missing() { + let err = anyhow::anyhow!("500 Internal Server Error"); + assert_eq!(parse_retry_after_ms(&err), None); + } + + #[test] + fn rate_limited_detection() { + assert!(is_rate_limited(&anyhow::anyhow!("429 Too Many Requests"))); + assert!(is_rate_limited(&anyhow::anyhow!( + "HTTP 429 rate limit exceeded" + ))); + assert!(!is_rate_limited(&anyhow::anyhow!("401 Unauthorized"))); + assert!(!is_rate_limited(&anyhow::anyhow!( + "500 Internal Server Error" + ))); + } + + #[test] + fn compute_backoff_uses_retry_after() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("429 Retry-After: 3"); + assert_eq!(provider.compute_backoff(500, &err), 3000); + } + + #[test] + fn compute_backoff_caps_at_30s() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("429 Retry-After: 120"); + assert_eq!(provider.compute_backoff(500, &err), 30_000); + } + + #[test] + fn compute_backoff_falls_back_to_base() { + let provider = ReliableProvider::new(vec![], 0, 500); + let err = anyhow::anyhow!("500 Server Error"); + assert_eq!(provider.compute_backoff(500, &err), 500); + } + + // ── Arc Provider impl for test ── + + #[async_trait] + impl Provider for Arc { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.as_ref() + .chat_with_system(system_prompt, message, model, temperature) + .await + } + } } From 8bcb5efa8ac256f5b44d75c83e5c6dd5b33133c6 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 22:06:40 +0800 Subject: [PATCH 2/2] fix(ci): align reliable provider tests with ChatResponse --- src/providers/reliable.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 804730d..423bfff 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -392,13 +392,13 @@ mod tests { _message: &str, model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); self.models_seen.lock().unwrap().push(model.to_string()); if self.fail_models.contains(&model) { anyhow::bail!("500 model {} unavailable", model); } - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } } @@ -666,7 +666,7 @@ mod tests { .with_model_fallbacks(fallbacks); let result = provider.chat("hello", "claude-opus", 0.0).await.unwrap(); - assert_eq!(result, "ok from sonnet"); + assert_eq!(result.text_or_empty(), "ok from sonnet"); let seen = mock.models_seen.lock().unwrap(); assert_eq!(seen.len(), 2); @@ -725,7 +725,7 @@ mod tests { ); // No model_fallbacks set — should work exactly as before let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "ok"); + assert_eq!(result.text_or_empty(), "ok"); assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -822,7 +822,7 @@ mod tests { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.as_ref() .chat_with_system(system_prompt, message, model, temperature) .await