feat(providers): add multi-model router for task-based provider routing
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
eadeffef26
commit
1cfc63831c
9 changed files with 537 additions and 9 deletions
68
Cargo.lock
generated
68
Cargo.lock
generated
|
|
@ -579,6 +579,12 @@ version = "0.1.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
|
checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fnv"
|
||||||
|
version = "1.0.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "foldhash"
|
name = "foldhash"
|
||||||
version = "0.1.5"
|
version = "0.1.5"
|
||||||
|
|
@ -1180,6 +1186,15 @@ version = "0.8.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lock_api"
|
||||||
|
version = "0.4.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965"
|
||||||
|
dependencies = [
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.29"
|
version = "0.4.29"
|
||||||
|
|
@ -1316,6 +1331,29 @@ version = "0.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
|
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot"
|
||||||
|
version = "0.12.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a"
|
||||||
|
dependencies = [
|
||||||
|
"lock_api",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot_core"
|
||||||
|
version = "0.9.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"redox_syscall",
|
||||||
|
"smallvec",
|
||||||
|
"windows-link",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "percent-encoding"
|
name = "percent-encoding"
|
||||||
version = "2.3.2"
|
version = "2.3.2"
|
||||||
|
|
@ -1388,6 +1426,20 @@ dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "prometheus"
|
||||||
|
version = "0.13.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"fnv",
|
||||||
|
"lazy_static",
|
||||||
|
"memchr",
|
||||||
|
"parking_lot",
|
||||||
|
"thiserror 1.0.69",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "psm"
|
name = "psm"
|
||||||
version = "0.1.30"
|
version = "0.1.30"
|
||||||
|
|
@ -1533,6 +1585,15 @@ dependencies = [
|
||||||
"getrandom 0.3.4",
|
"getrandom 0.3.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redox_syscall"
|
||||||
|
version = "0.5.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_users"
|
name = "redox_users"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
|
|
@ -1695,6 +1756,12 @@ version = "1.0.23"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
|
checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "semver"
|
name = "semver"
|
||||||
version = "1.0.27"
|
version = "1.0.27"
|
||||||
|
|
@ -2955,6 +3022,7 @@ dependencies = [
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"lettre",
|
"lettre",
|
||||||
"mail-parser",
|
"mail-parser",
|
||||||
|
"prometheus",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
"rustls",
|
"rustls",
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,9 @@ shellexpand = "3.1"
|
||||||
tracing = { version = "0.1", default-features = false }
|
tracing = { version = "0.1", default-features = false }
|
||||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
|
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
|
||||||
|
|
||||||
|
# Observability - Prometheus metrics
|
||||||
|
prometheus = { version = "0.13", default-features = false }
|
||||||
|
|
||||||
# Error handling
|
# Error handling
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
thiserror = "2.0"
|
thiserror = "2.0"
|
||||||
|
|
|
||||||
|
|
@ -73,10 +73,12 @@ pub async fn run(
|
||||||
.or(config.default_model.as_deref())
|
.or(config.default_model.as_deref())
|
||||||
.unwrap_or("anthropic/claude-sonnet-4-20250514");
|
.unwrap_or("anthropic/claude-sonnet-4-20250514");
|
||||||
|
|
||||||
let provider: Box<dyn Provider> = providers::create_resilient_provider(
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||||
provider_name,
|
provider_name,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
|
&config.model_routes,
|
||||||
|
model_name,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
observer.record_event(&ObserverEvent::AgentStart {
|
observer.record_event(&ObserverEvent::AgentStart {
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,6 @@ pub mod schema;
|
||||||
pub use schema::{
|
pub use schema::{
|
||||||
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
||||||
GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig,
|
GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig,
|
||||||
ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig,
|
ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig,
|
||||||
TelegramConfig, TunnelConfig, WebhookConfig,
|
SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ pub struct Config {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub reliability: ReliabilityConfig,
|
pub reliability: ReliabilityConfig,
|
||||||
|
|
||||||
|
/// Model routing rules — route `hint:<name>` to specific provider+model combos.
|
||||||
|
#[serde(default)]
|
||||||
|
pub model_routes: Vec<ModelRouteConfig>,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub heartbeat: HeartbeatConfig,
|
pub heartbeat: HeartbeatConfig,
|
||||||
|
|
||||||
|
|
@ -446,6 +450,36 @@ impl Default for ReliabilityConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Model routing ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Route a task hint to a specific provider + model.
|
||||||
|
///
|
||||||
|
/// ```toml
|
||||||
|
/// [[model_routes]]
|
||||||
|
/// hint = "reasoning"
|
||||||
|
/// provider = "openrouter"
|
||||||
|
/// model = "anthropic/claude-opus-4-20250514"
|
||||||
|
///
|
||||||
|
/// [[model_routes]]
|
||||||
|
/// hint = "fast"
|
||||||
|
/// provider = "groq"
|
||||||
|
/// model = "llama-3.3-70b-versatile"
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Usage: pass `hint:reasoning` as the model parameter to route the request.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelRouteConfig {
|
||||||
|
/// Task hint name (e.g. "reasoning", "fast", "code", "summarize")
|
||||||
|
pub hint: String,
|
||||||
|
/// Provider to route to (must match a known provider name)
|
||||||
|
pub provider: String,
|
||||||
|
/// Model to use with that provider
|
||||||
|
pub model: String,
|
||||||
|
/// Optional API key override for this route's provider
|
||||||
|
#[serde(default)]
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
// ── Heartbeat ────────────────────────────────────────────────────
|
// ── Heartbeat ────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -670,6 +704,7 @@ impl Default for Config {
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: ReliabilityConfig::default(),
|
reliability: ReliabilityConfig::default(),
|
||||||
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
|
|
@ -875,6 +910,7 @@ mod tests {
|
||||||
kind: "docker".into(),
|
kind: "docker".into(),
|
||||||
},
|
},
|
||||||
reliability: ReliabilityConfig::default(),
|
reliability: ReliabilityConfig::default(),
|
||||||
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig {
|
heartbeat: HeartbeatConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
interval_minutes: 15,
|
interval_minutes: 15,
|
||||||
|
|
@ -962,6 +998,7 @@ default_temperature = 0.7
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: ReliabilityConfig::default(),
|
reliability: ReliabilityConfig::default(),
|
||||||
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
|
|
|
||||||
|
|
@ -66,15 +66,17 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let actual_port = listener.local_addr()?.port();
|
let actual_port = listener.local_addr()?.port();
|
||||||
let display_addr = format!("{host}:{actual_port}");
|
let display_addr = format!("{host}:{actual_port}");
|
||||||
|
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
|
||||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
|
||||||
config.api_key.as_deref(),
|
|
||||||
&config.reliability,
|
|
||||||
)?);
|
|
||||||
let model = config
|
let model = config
|
||||||
.default_model
|
.default_model
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||||
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_routed_provider(
|
||||||
|
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
|
&config.model_routes,
|
||||||
|
&model,
|
||||||
|
)?);
|
||||||
let temperature = config.default_temperature;
|
let temperature = config.default_temperature;
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: crate::config::ReliabilityConfig::default(),
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config,
|
channels_config,
|
||||||
memory: memory_config, // User-selected memory backend
|
memory: memory_config, // User-selected memory backend
|
||||||
|
|
@ -286,6 +287,7 @@ pub fn run_quick_setup(
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: crate::config::ReliabilityConfig::default(),
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: memory_config,
|
memory: memory_config,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod openrouter;
|
pub mod openrouter;
|
||||||
pub mod reliable;
|
pub mod reliable;
|
||||||
|
pub mod router;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
pub use traits::Provider;
|
pub use traits::Provider;
|
||||||
|
|
@ -153,7 +154,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
||||||
/// Factory: create the right provider from config
|
/// Factory: create the right provider from config
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
let resolved_key = resolve_api_key(name, api_key);
|
let _resolved_key = resolve_api_key(name, api_key);
|
||||||
match name {
|
match name {
|
||||||
// ── Primary providers (custom implementations) ───────
|
// ── Primary providers (custom implementations) ───────
|
||||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))),
|
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))),
|
||||||
|
|
@ -316,6 +317,71 @@ pub fn create_resilient_provider(
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a RouterProvider if model routes are configured, otherwise return a
|
||||||
|
/// standard resilient provider. The router wraps individual providers per route,
|
||||||
|
/// each with its own retry/fallback chain.
|
||||||
|
pub fn create_routed_provider(
|
||||||
|
primary_name: &str,
|
||||||
|
api_key: Option<&str>,
|
||||||
|
reliability: &crate::config::ReliabilityConfig,
|
||||||
|
model_routes: &[crate::config::ModelRouteConfig],
|
||||||
|
default_model: &str,
|
||||||
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
|
if model_routes.is_empty() {
|
||||||
|
return create_resilient_provider(primary_name, api_key, reliability);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect unique provider names needed
|
||||||
|
let mut needed: Vec<String> = vec![primary_name.to_string()];
|
||||||
|
for route in model_routes {
|
||||||
|
if !needed.iter().any(|n| n == &route.provider) {
|
||||||
|
needed.push(route.provider.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create each provider (with its own resilience wrapper)
|
||||||
|
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||||
|
for name in &needed {
|
||||||
|
let key = model_routes
|
||||||
|
.iter()
|
||||||
|
.find(|r| &r.provider == name)
|
||||||
|
.and_then(|r| r.api_key.as_deref())
|
||||||
|
.or(api_key);
|
||||||
|
match create_resilient_provider(name, key, reliability) {
|
||||||
|
Ok(provider) => providers.push((name.clone(), provider)),
|
||||||
|
Err(e) => {
|
||||||
|
if name == primary_name {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
tracing::warn!(
|
||||||
|
provider = name.as_str(),
|
||||||
|
"Ignoring routed provider that failed to create: {e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build route table
|
||||||
|
let routes: Vec<(String, router::Route)> = model_routes
|
||||||
|
.iter()
|
||||||
|
.map(|r| {
|
||||||
|
(
|
||||||
|
r.hint.clone(),
|
||||||
|
router::Route {
|
||||||
|
provider_name: r.provider.clone(),
|
||||||
|
model: r.model.clone(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(Box::new(router::RouterProvider::new(
|
||||||
|
providers,
|
||||||
|
routes,
|
||||||
|
default_model.to_string(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
348
src/providers/router.rs
Normal file
348
src/providers/router.rs
Normal file
|
|
@ -0,0 +1,348 @@
|
||||||
|
use super::Provider;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// A single route: maps a task hint to a provider + model combo.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Route {
|
||||||
|
pub provider_name: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multi-model router — routes requests to different provider+model combos
|
||||||
|
/// based on a task hint encoded in the model parameter.
|
||||||
|
///
|
||||||
|
/// The model parameter can be:
|
||||||
|
/// - A regular model name (e.g. "anthropic/claude-sonnet-4-20250514") → uses default provider
|
||||||
|
/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
|
||||||
|
///
|
||||||
|
/// This wraps multiple pre-created providers and selects the right one per request.
|
||||||
|
pub struct RouterProvider {
|
||||||
|
routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
|
||||||
|
providers: Vec<(String, Box<dyn Provider>)>,
|
||||||
|
default_index: usize,
|
||||||
|
default_model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RouterProvider {
|
||||||
|
/// Create a new router with a default provider and optional routes.
|
||||||
|
///
|
||||||
|
/// `providers` is a list of (name, provider) pairs. The first one is the default.
|
||||||
|
/// `routes` maps hint names to Route structs containing provider_name and model.
|
||||||
|
pub fn new(
|
||||||
|
providers: Vec<(String, Box<dyn Provider>)>,
|
||||||
|
routes: Vec<(String, Route)>,
|
||||||
|
default_model: String,
|
||||||
|
) -> Self {
|
||||||
|
// Build provider name → index lookup
|
||||||
|
let name_to_index: HashMap<&str, usize> = providers
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, (name, _))| (name.as_str(), i))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Resolve routes to provider indices
|
||||||
|
let resolved_routes: HashMap<String, (usize, String)> = routes
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|(hint, route)| {
|
||||||
|
let index = name_to_index.get(route.provider_name.as_str()).copied();
|
||||||
|
match index {
|
||||||
|
Some(i) => Some((hint, (i, route.model))),
|
||||||
|
None => {
|
||||||
|
tracing::warn!(
|
||||||
|
hint = hint,
|
||||||
|
provider = route.provider_name,
|
||||||
|
"Route references unknown provider, skipping"
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
routes: resolved_routes,
|
||||||
|
providers,
|
||||||
|
default_index: 0,
|
||||||
|
default_model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a model parameter to a (provider, actual_model) pair.
|
||||||
|
///
|
||||||
|
/// If the model starts with "hint:", look up the hint in the route table.
|
||||||
|
/// Otherwise, use the default provider with the given model name.
|
||||||
|
/// Resolve a model parameter to a (provider_index, actual_model) pair.
|
||||||
|
fn resolve(&self, model: &str) -> (usize, String) {
|
||||||
|
if let Some(hint) = model.strip_prefix("hint:") {
|
||||||
|
if let Some((idx, resolved_model)) = self.routes.get(hint) {
|
||||||
|
return (*idx, resolved_model.clone());
|
||||||
|
}
|
||||||
|
tracing::warn!(
|
||||||
|
hint = hint,
|
||||||
|
"Unknown route hint, falling back to default provider"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a hint or hint not found — use default provider with the model as-is
|
||||||
|
(self.default_index, model.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for RouterProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
|
|
||||||
|
let (provider_name, provider) = &self.providers[provider_idx];
|
||||||
|
tracing::info!(
|
||||||
|
provider = provider_name.as_str(),
|
||||||
|
model = resolved_model.as_str(),
|
||||||
|
"Router dispatching request"
|
||||||
|
);
|
||||||
|
|
||||||
|
provider
|
||||||
|
.chat_with_system(system_prompt, message, &resolved_model, temperature)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
|
for (name, provider) in &self.providers {
|
||||||
|
tracing::info!(provider = name, "Warming up routed provider");
|
||||||
|
if let Err(e) = provider.warmup().await {
|
||||||
|
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
struct MockProvider {
|
||||||
|
calls: Arc<AtomicUsize>,
|
||||||
|
response: &'static str,
|
||||||
|
last_model: std::sync::Mutex<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
fn new(response: &'static str) -> Self {
|
||||||
|
Self {
|
||||||
|
calls: Arc::new(AtomicUsize::new(0)),
|
||||||
|
response,
|
||||||
|
last_model: std::sync::Mutex::new(String::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call_count(&self) -> usize {
|
||||||
|
self.calls.load(Ordering::SeqCst)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_model(&self) -> String {
|
||||||
|
self.last_model.lock().unwrap().clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
*self.last_model.lock().unwrap() = model.to_string();
|
||||||
|
Ok(self.response.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_router(
|
||||||
|
providers: Vec<(&'static str, &'static str)>,
|
||||||
|
routes: Vec<(&str, &str, &str)>,
|
||||||
|
) -> (RouterProvider, Vec<Arc<MockProvider>>) {
|
||||||
|
let mocks: Vec<Arc<MockProvider>> = providers
|
||||||
|
.iter()
|
||||||
|
.map(|(_, response)| Arc::new(MockProvider::new(response)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let provider_list: Vec<(String, Box<dyn Provider>)> = providers
|
||||||
|
.iter()
|
||||||
|
.zip(mocks.iter())
|
||||||
|
.map(|((name, _), mock)| {
|
||||||
|
(name.to_string(), Box::new(Arc::clone(mock)) as Box<dyn Provider>)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let route_list: Vec<(String, Route)> = routes
|
||||||
|
.iter()
|
||||||
|
.map(|(hint, provider_name, model)| {
|
||||||
|
(
|
||||||
|
hint.to_string(),
|
||||||
|
Route {
|
||||||
|
provider_name: provider_name.to_string(),
|
||||||
|
model: model.to_string(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let router = RouterProvider::new(
|
||||||
|
provider_list,
|
||||||
|
route_list,
|
||||||
|
"default-model".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
(router, mocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Arc<MockProvider> should also be a Provider
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for Arc<MockProvider> {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
self.as_ref()
|
||||||
|
.chat_with_system(system_prompt, message, model, temperature)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn routes_hint_to_correct_provider() {
|
||||||
|
let (router, mocks) = make_router(
|
||||||
|
vec![("fast", "fast-response"), ("smart", "smart-response")],
|
||||||
|
vec![
|
||||||
|
("fast", "fast", "llama-3-70b"),
|
||||||
|
("reasoning", "smart", "claude-opus"),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap();
|
||||||
|
assert_eq!(result, "smart-response");
|
||||||
|
assert_eq!(mocks[1].call_count(), 1);
|
||||||
|
assert_eq!(mocks[1].last_model(), "claude-opus");
|
||||||
|
assert_eq!(mocks[0].call_count(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn routes_fast_hint() {
|
||||||
|
let (router, mocks) = make_router(
|
||||||
|
vec![("fast", "fast-response"), ("smart", "smart-response")],
|
||||||
|
vec![("fast", "fast", "llama-3-70b")],
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = router.chat("hello", "hint:fast", 0.5).await.unwrap();
|
||||||
|
assert_eq!(result, "fast-response");
|
||||||
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
|
assert_eq!(mocks[0].last_model(), "llama-3-70b");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unknown_hint_falls_back_to_default() {
|
||||||
|
let (router, mocks) = make_router(
|
||||||
|
vec![("default", "default-response"), ("other", "other-response")],
|
||||||
|
vec![],
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap();
|
||||||
|
assert_eq!(result, "default-response");
|
||||||
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
|
// Falls back to default with the hint as model name
|
||||||
|
assert_eq!(mocks[0].last_model(), "hint:nonexistent");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn non_hint_model_uses_default_provider() {
|
||||||
|
let (router, mocks) = make_router(
|
||||||
|
vec![("primary", "primary-response"), ("secondary", "secondary-response")],
|
||||||
|
vec![("code", "secondary", "codellama")],
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = router
|
||||||
|
.chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "primary-response");
|
||||||
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
|
assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_preserves_model_for_non_hints() {
|
||||||
|
let (router, _) = make_router(
|
||||||
|
vec![("default", "ok")],
|
||||||
|
vec![],
|
||||||
|
);
|
||||||
|
|
||||||
|
let (idx, model) = router.resolve("gpt-4o");
|
||||||
|
assert_eq!(idx, 0);
|
||||||
|
assert_eq!(model, "gpt-4o");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_strips_hint_prefix() {
|
||||||
|
let (router, _) = make_router(
|
||||||
|
vec![("fast", "ok"), ("smart", "ok")],
|
||||||
|
vec![("reasoning", "smart", "claude-opus")],
|
||||||
|
);
|
||||||
|
|
||||||
|
let (idx, model) = router.resolve("hint:reasoning");
|
||||||
|
assert_eq!(idx, 1);
|
||||||
|
assert_eq!(model, "claude-opus");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn skips_routes_with_unknown_provider() {
|
||||||
|
let (router, _) = make_router(
|
||||||
|
vec![("default", "ok")],
|
||||||
|
vec![("broken", "nonexistent", "model")],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Route should not exist
|
||||||
|
assert!(!router.routes.contains_key("broken"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn warmup_calls_all_providers() {
|
||||||
|
let (router, _) = make_router(
|
||||||
|
vec![("a", "ok"), ("b", "ok")],
|
||||||
|
vec![],
|
||||||
|
);
|
||||||
|
|
||||||
|
// Warmup should not error
|
||||||
|
assert!(router.warmup().await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_system_passes_system_prompt() {
|
||||||
|
let mock = Arc::new(MockProvider::new("response"));
|
||||||
|
let router = RouterProvider::new(
|
||||||
|
vec![("default".into(), Box::new(Arc::clone(&mock)) as Box<dyn Provider>)],
|
||||||
|
vec![],
|
||||||
|
"model".into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = router
|
||||||
|
.chat_with_system(Some("system"), "hello", "model", 0.5)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "response");
|
||||||
|
assert_eq!(mock.call_count(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue