diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 4c7a44b..29f1903 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,4 +1,4 @@ -use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult}; +use super::traits::{ChatMessage, ChatResponse, StreamChunk, StreamOptions, StreamResult}; use super::Provider; use async_trait::async_trait; use futures_util::{stream, StreamExt}; @@ -353,6 +353,110 @@ impl Provider for ReliableProvider { ) } + fn supports_native_tools(&self) -> bool { + self.providers + .first() + .map(|(_, p)| p.supports_native_tools()) + .unwrap_or(false) + } + + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let models = self.model_chain(model); + let mut failures = Vec::new(); + + for current_model in &models { + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; + + for attempt in 0..=self.max_retries { + match provider + .chat_with_tools(messages, tools, current_model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 || *current_model != model { + tracing::info!( + provider = provider_name, + model = *current_model, + attempt, + original_model = model, + "Provider recovered (failover/retry)" + ); + } + return Ok(resp); + } + Err(e) => { + let non_retryable = is_non_retryable(&e); + let rate_limited = is_rate_limited(&e); + + let failure_reason = if rate_limited { + "rate_limited" + } else if non_retryable { + "non_retryable" + } else { + "retryable" + }; + failures.push(format!( + "{provider_name}/{current_model} attempt {}/{}: {failure_reason}", + attempt + 1, + self.max_retries + 1 + )); + + if rate_limited { + if let Some(new_key) = self.rotate_key() { + tracing::info!( + provider = provider_name, + "Rate limited, rotated API key (key ending ...{})", + &new_key[new_key.len().saturating_sub(4)..] + ); + } + } + + if non_retryable { + tracing::warn!( + provider = provider_name, + model = *current_model, + "Non-retryable error, moving on" + ); + break; + } + + if attempt < self.max_retries { + let wait = self.compute_backoff(backoff_ms, &e); + tracing::warn!( + provider = provider_name, + model = *current_model, + attempt = attempt + 1, + backoff_ms = wait, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(wait)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } + } + } + } + + tracing::warn!( + provider = provider_name, + model = *current_model, + "Exhausted retries, trying next provider/model" + ); + } + } + + anyhow::bail!( + "All providers/models failed. Attempts:\n{}", + failures.join("\n") + ) + } + fn supports_streaming(&self) -> bool { self.providers.iter().any(|(_, p)| p.supports_streaming()) }