From 1cfc63831cf56a6dd92f5b05d17d5d34b0ea7cd4 Mon Sep 17 00:00:00 2001 From: Argenis Date: Sun, 15 Feb 2026 11:40:58 -0500 Subject: [PATCH] feat(providers): add multi-model router for task-based provider routing Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 68 ++++++++ Cargo.toml | 3 + src/agent/loop_.rs | 4 +- src/config/mod.rs | 4 +- src/config/schema.rs | 37 +++++ src/gateway/mod.rs | 12 +- src/onboard/wizard.rs | 2 + src/providers/mod.rs | 68 +++++++- src/providers/router.rs | 348 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 537 insertions(+), 9 deletions(-) create mode 100644 src/providers/router.rs diff --git a/Cargo.lock b/Cargo.lock index fdbe1e0..ced7e82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -579,6 +579,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foldhash" version = "0.1.5" @@ -1180,6 +1186,15 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -1316,6 +1331,29 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" 
+dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -1388,6 +1426,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prometheus" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1" +dependencies = [ + "cfg-if", + "fnv", + "lazy_static", + "memchr", + "parking_lot", + "thiserror 1.0.69", +] + [[package]] name = "psm" version = "0.1.30" @@ -1533,6 +1585,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_users" version = "0.4.6" @@ -1695,6 +1756,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "semver" version = "1.0.27" @@ -2955,6 +3022,7 @@ dependencies = [ "http-body-util", "lettre", "mail-parser", + "prometheus", "reqwest", "rusqlite", "rustls", diff --git a/Cargo.toml b/Cargo.toml index ff7c96d..a9a1924 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,9 @@ shellexpand = "3.1" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.3", default-features = 
false, features = ["fmt", "ansi"] } +# Observability - Prometheus metrics +prometheus = { version = "0.13", default-features = false } + # Error handling anyhow = "1.0" thiserror = "2.0" diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 9ca3fd4..19ed860 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -73,10 +73,12 @@ pub async fn run( .or(config.default_model.as_deref()) .unwrap_or("anthropic/claude-sonnet-4-20250514"); - let provider: Box = providers::create_resilient_provider( + let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), &config.reliability, + &config.model_routes, + model_name, )?; observer.record_event(&ObserverEvent::AgentStart { diff --git a/src/config/mod.rs b/src/config/mod.rs index f5849c1..e5a6521 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,6 +3,6 @@ pub mod schema; pub use schema::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig, - ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, - TelegramConfig, TunnelConfig, WebhookConfig, + ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, + SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index e93eda4..764ba69 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -32,6 +32,10 @@ pub struct Config { #[serde(default)] pub reliability: ReliabilityConfig, + /// Model routing rules — route `hint:` to specific provider+model combos. + #[serde(default)] + pub model_routes: Vec, + #[serde(default)] pub heartbeat: HeartbeatConfig, @@ -446,6 +450,36 @@ impl Default for ReliabilityConfig { } } +// ── Model routing ──────────────────────────────────────────────── + +/// Route a task hint to a specific provider + model. 
+/// +/// ```toml +/// [[model_routes]] +/// hint = "reasoning" +/// provider = "openrouter" +/// model = "anthropic/claude-opus-4-20250514" +/// +/// [[model_routes]] +/// hint = "fast" +/// provider = "groq" +/// model = "llama-3.3-70b-versatile" +/// ``` +/// +/// Usage: pass `hint:reasoning` as the model parameter to route the request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRouteConfig { + /// Task hint name (e.g. "reasoning", "fast", "code", "summarize") + pub hint: String, + /// Provider to route to (must match a known provider name) + pub provider: String, + /// Model to use with that provider + pub model: String, + /// Optional API key override for this route's provider + #[serde(default)] + pub api_key: Option, +} + // ── Heartbeat ──────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -670,6 +704,7 @@ impl Default for Config { autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -875,6 +910,7 @@ mod tests { kind: "docker".into(), }, reliability: ReliabilityConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig { enabled: true, interval_minutes: 15, @@ -962,6 +998,7 @@ default_temperature = 0.7 autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 918dd43..bede685 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -66,15 +66,17 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let actual_port = listener.local_addr()?.port(); let display_addr = 
format!("{host}:{actual_port}"); - let provider: Arc = Arc::from(providers::create_resilient_provider( - config.default_provider.as_deref().unwrap_or("openrouter"), - config.api_key.as_deref(), - &config.reliability, - )?); let model = config .default_model .clone() .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); + let provider: Arc = Arc::from(providers::create_routed_provider( + config.default_provider.as_deref().unwrap_or("openrouter"), + config.api_key.as_deref(), + &config.reliability, + &config.model_routes, + &model, + )?); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory( &config.memory, diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 41831c2..ec95aa3 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -96,6 +96,7 @@ pub fn run_wizard() -> Result { autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config, memory: memory_config, // User-selected memory backend @@ -286,6 +287,7 @@ pub fn run_quick_setup( autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: memory_config, diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 2cc8dc0..1ff85b7 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -5,6 +5,7 @@ pub mod ollama; pub mod openai; pub mod openrouter; pub mod reliable; +pub mod router; pub mod traits; pub use traits::Provider; @@ -153,7 +154,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option { /// Factory: create the right provider from config #[allow(clippy::too_many_lines)] pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { - let resolved_key = 
resolve_api_key(name, api_key); + let _resolved_key = resolve_api_key(name, api_key); match name { // ── Primary providers (custom implementations) ─────── "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))), @@ -316,6 +317,71 @@ pub fn create_resilient_provider( ))) } +/// Create a RouterProvider if model routes are configured, otherwise return a +/// standard resilient provider. The router wraps individual providers per route, +/// each with its own retry/fallback chain. +pub fn create_routed_provider( + primary_name: &str, + api_key: Option<&str>, + reliability: &crate::config::ReliabilityConfig, + model_routes: &[crate::config::ModelRouteConfig], + default_model: &str, +) -> anyhow::Result> { + if model_routes.is_empty() { + return create_resilient_provider(primary_name, api_key, reliability); + } + + // Collect unique provider names needed + let mut needed: Vec = vec![primary_name.to_string()]; + for route in model_routes { + if !needed.iter().any(|n| n == &route.provider) { + needed.push(route.provider.clone()); + } + } + + // Create each provider (with its own resilience wrapper) + let mut providers: Vec<(String, Box)> = Vec::new(); + for name in &needed { + let key = model_routes + .iter() + .find(|r| &r.provider == name) + .and_then(|r| r.api_key.as_deref()) + .or(api_key); + match create_resilient_provider(name, key, reliability) { + Ok(provider) => providers.push((name.clone(), provider)), + Err(e) => { + if name == primary_name { + return Err(e); + } + tracing::warn!( + provider = name.as_str(), + "Ignoring routed provider that failed to create: {e}" + ); + } + } + } + + // Build route table + let routes: Vec<(String, router::Route)> = model_routes + .iter() + .map(|r| { + ( + r.hint.clone(), + router::Route { + provider_name: r.provider.clone(), + model: r.model.clone(), + }, + ) + }) + .collect(); + + Ok(Box::new(router::RouterProvider::new( + providers, + routes, + default_model.to_string(), + ))) +} + #[cfg(test)] mod tests { 
use super::*; diff --git a/src/providers/router.rs b/src/providers/router.rs new file mode 100644 index 0000000..52dab47 --- /dev/null +++ b/src/providers/router.rs @@ -0,0 +1,348 @@ +use super::Provider; +use async_trait::async_trait; +use std::collections::HashMap; + +/// A single route: maps a task hint to a provider + model combo. +#[derive(Debug, Clone)] +pub struct Route { + pub provider_name: String, + pub model: String, +} + +/// Multi-model router — routes requests to different provider+model combos +/// based on a task hint encoded in the model parameter. +/// +/// The model parameter can be: +/// - A regular model name (e.g. "anthropic/claude-sonnet-4-20250514") → uses default provider +/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table +/// +/// This wraps multiple pre-created providers and selects the right one per request. +pub struct RouterProvider { + routes: HashMap, // hint → (provider_index, model) + providers: Vec<(String, Box)>, + default_index: usize, + default_model: String, +} + +impl RouterProvider { + /// Create a new router with a default provider and optional routes. + /// + /// `providers` is a list of (name, provider) pairs. The first one is the default. + /// `routes` maps hint names to Route structs containing provider_name and model. 
+ pub fn new( + providers: Vec<(String, Box)>, + routes: Vec<(String, Route)>, + default_model: String, + ) -> Self { + // Build provider name → index lookup + let name_to_index: HashMap<&str, usize> = providers + .iter() + .enumerate() + .map(|(i, (name, _))| (name.as_str(), i)) + .collect(); + + // Resolve routes to provider indices + let resolved_routes: HashMap = routes + .into_iter() + .filter_map(|(hint, route)| { + let index = name_to_index.get(route.provider_name.as_str()).copied(); + match index { + Some(i) => Some((hint, (i, route.model))), + None => { + tracing::warn!( + hint = hint, + provider = route.provider_name, + "Route references unknown provider, skipping" + ); + None + } + } + }) + .collect(); + + Self { + routes: resolved_routes, + providers, + default_index: 0, + default_model, + } + } + + /// Resolve a model parameter to a (provider_index, actual_model) pair. + /// + /// If the model starts with "hint:", look up the hint in the route table. + /// Otherwise, use the default provider with the given model name. + /// An unknown hint also falls back to the default provider, model string unchanged. 
+ fn resolve(&self, model: &str) -> (usize, String) { + if let Some(hint) = model.strip_prefix("hint:") { + if let Some((idx, resolved_model)) = self.routes.get(hint) { + return (*idx, resolved_model.clone()); + } + tracing::warn!( + hint = hint, + "Unknown route hint, falling back to default provider" + ); + } + + // Not a hint or hint not found — use default provider with the model as-is + (self.default_index, model.to_string()) + } +} + +#[async_trait] +impl Provider for RouterProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + + let (provider_name, provider) = &self.providers[provider_idx]; + tracing::info!( + provider = provider_name.as_str(), + model = resolved_model.as_str(), + "Router dispatching request" + ); + + provider + .chat_with_system(system_prompt, message, &resolved_model, temperature) + .await + } + + async fn warmup(&self) -> anyhow::Result<()> { + for (name, provider) in &self.providers { + tracing::info!(provider = name, "Warming up routed provider"); + if let Err(e) = provider.warmup().await { + tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}"); + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + struct MockProvider { + calls: Arc, + response: &'static str, + last_model: std::sync::Mutex, + } + + impl MockProvider { + fn new(response: &'static str) -> Self { + Self { + calls: Arc::new(AtomicUsize::new(0)), + response, + last_model: std::sync::Mutex::new(String::new()), + } + } + + fn call_count(&self) -> usize { + self.calls.load(Ordering::SeqCst) + } + + fn last_model(&self) -> String { + self.last_model.lock().unwrap().clone() + } + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: 
&str, + model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + *self.last_model.lock().unwrap() = model.to_string(); + Ok(self.response.to_string()) + } + } + + fn make_router( + providers: Vec<(&'static str, &'static str)>, + routes: Vec<(&str, &str, &str)>, + ) -> (RouterProvider, Vec>) { + let mocks: Vec> = providers + .iter() + .map(|(_, response)| Arc::new(MockProvider::new(response))) + .collect(); + + let provider_list: Vec<(String, Box)> = providers + .iter() + .zip(mocks.iter()) + .map(|((name, _), mock)| { + (name.to_string(), Box::new(Arc::clone(mock)) as Box) + }) + .collect(); + + let route_list: Vec<(String, Route)> = routes + .iter() + .map(|(hint, provider_name, model)| { + ( + hint.to_string(), + Route { + provider_name: provider_name.to_string(), + model: model.to_string(), + }, + ) + }) + .collect(); + + let router = RouterProvider::new( + provider_list, + route_list, + "default-model".to_string(), + ); + + (router, mocks) + } + + // Arc should also be a Provider + #[async_trait] + impl Provider for Arc { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.as_ref() + .chat_with_system(system_prompt, message, model, temperature) + .await + } + } + + #[tokio::test] + async fn routes_hint_to_correct_provider() { + let (router, mocks) = make_router( + vec![("fast", "fast-response"), ("smart", "smart-response")], + vec![ + ("fast", "fast", "llama-3-70b"), + ("reasoning", "smart", "claude-opus"), + ], + ); + + let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap(); + assert_eq!(result, "smart-response"); + assert_eq!(mocks[1].call_count(), 1); + assert_eq!(mocks[1].last_model(), "claude-opus"); + assert_eq!(mocks[0].call_count(), 0); + } + + #[tokio::test] + async fn routes_fast_hint() { + let (router, mocks) = make_router( + vec![("fast", "fast-response"), ("smart", "smart-response")], + 
vec![("fast", "fast", "llama-3-70b")], + ); + + let result = router.chat("hello", "hint:fast", 0.5).await.unwrap(); + assert_eq!(result, "fast-response"); + assert_eq!(mocks[0].call_count(), 1); + assert_eq!(mocks[0].last_model(), "llama-3-70b"); + } + + #[tokio::test] + async fn unknown_hint_falls_back_to_default() { + let (router, mocks) = make_router( + vec![("default", "default-response"), ("other", "other-response")], + vec![], + ); + + let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap(); + assert_eq!(result, "default-response"); + assert_eq!(mocks[0].call_count(), 1); + // Falls back to default with the hint as model name + assert_eq!(mocks[0].last_model(), "hint:nonexistent"); + } + + #[tokio::test] + async fn non_hint_model_uses_default_provider() { + let (router, mocks) = make_router( + vec![("primary", "primary-response"), ("secondary", "secondary-response")], + vec![("code", "secondary", "codellama")], + ); + + let result = router + .chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) + .await + .unwrap(); + assert_eq!(result, "primary-response"); + assert_eq!(mocks[0].call_count(), 1); + assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514"); + } + + #[test] + fn resolve_preserves_model_for_non_hints() { + let (router, _) = make_router( + vec![("default", "ok")], + vec![], + ); + + let (idx, model) = router.resolve("gpt-4o"); + assert_eq!(idx, 0); + assert_eq!(model, "gpt-4o"); + } + + #[test] + fn resolve_strips_hint_prefix() { + let (router, _) = make_router( + vec![("fast", "ok"), ("smart", "ok")], + vec![("reasoning", "smart", "claude-opus")], + ); + + let (idx, model) = router.resolve("hint:reasoning"); + assert_eq!(idx, 1); + assert_eq!(model, "claude-opus"); + } + + #[test] + fn skips_routes_with_unknown_provider() { + let (router, _) = make_router( + vec![("default", "ok")], + vec![("broken", "nonexistent", "model")], + ); + + // Route should not exist + assert!(!router.routes.contains_key("broken")); 
+ } + + #[tokio::test] + async fn warmup_calls_all_providers() { + let (router, _) = make_router( + vec![("a", "ok"), ("b", "ok")], + vec![], + ); + + // Warmup should not error + assert!(router.warmup().await.is_ok()); + } + + #[tokio::test] + async fn chat_with_system_passes_system_prompt() { + let mock = Arc::new(MockProvider::new("response")); + let router = RouterProvider::new( + vec![("default".into(), Box::new(Arc::clone(&mock)) as Box)], + vec![], + "model".into(), + ); + + let result = router + .chat_with_system(Some("system"), "hello", "model", 0.5) + .await + .unwrap(); + assert_eq!(result, "response"); + assert_eq!(mock.call_count(), 1); + } +}