diff --git a/docs/channels-reference.md b/docs/channels-reference.md index 066840f..31ed8ef 100644 --- a/docs/channels-reference.md +++ b/docs/channels-reference.md @@ -13,6 +13,21 @@ zeroclaw channel doctor zeroclaw channel bind-telegram ``` +## In-Chat Runtime Model Switching (Telegram / Discord) + +When running `zeroclaw channel start` (or daemon mode), Telegram and Discord now support sender-scoped runtime switching: + +- `/models` — show available providers and current selection +- `/models ` — switch provider for the current sender session +- `/model` — show current model and cached model IDs (if available) +- `/model ` — switch model for the current sender session + +Notes: + +- Switching clears only that sender's in-memory conversation history to avoid cross-model context contamination. +- Model cache previews come from `zeroclaw models refresh --provider `. +- These are runtime chat commands, not CLI subcommands. + ## Channel Matrix | Channel | Config section | Access control field | Setup path | diff --git a/docs/commands-reference.md b/docs/commands-reference.md index 7b685ec..7c53cd7 100644 --- a/docs/commands-reference.md +++ b/docs/commands-reference.md @@ -80,6 +80,13 @@ Last verified: **February 18, 2026**. - `zeroclaw channel add ` - `zeroclaw channel remove ` +Runtime in-chat commands (Telegram/Discord while channel server is running): + +- `/models` +- `/models ` +- `/model` +- `/model ` + `add/remove` currently route you back to managed setup/manual config paths (not full declarative mutators yet). ### `integrations` @@ -118,4 +125,3 @@ To verify docs against your current binary quickly: zeroclaw --help zeroclaw --help ``` - diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 749d624..b189139 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -41,9 +41,10 @@ use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::{Context, Result}; +use serde::Deserialize; use std::collections::HashMap; use std::fmt::Write; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::process::Command; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; @@ -66,11 +67,42 @@ const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4; const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4; +const MODEL_CACHE_FILE: &str = "models_cache.json"; +const MODEL_CACHE_PREVIEW_LIMIT: usize = 10; + +type ProviderCacheMap = Arc>>>; +type RouteSelectionMap = Arc>>; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ChannelRouteSelection { + provider: String, + model: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ChannelRuntimeCommand { + ShowProviders, + SetProvider(String), + ShowModel, + SetModel(String), +} + +#[derive(Debug, Clone, Default, Deserialize)] +struct ModelCacheState { + entries: Vec, +} + +#[derive(Debug, Clone, Default, Deserialize)] +struct ModelCacheEntry { + provider: String, + models: Vec, +} #[derive(Clone)] struct ChannelRuntimeContext { channels_by_name: Arc>>, provider: Arc, + default_provider: Arc, memory: Arc, tools_registry: Arc>>, observer: Arc, @@ -81,12 +113,23 @@ struct ChannelRuntimeContext { max_tool_iterations: usize, min_relevance_score: f64, conversation_histories: ConversationHistoryMap, + provider_cache: ProviderCacheMap, + route_overrides: RouteSelectionMap, + api_key: Option, + api_url: Option, + reliability: Arc, + provider_runtime_options: providers::ProviderRuntimeOptions, + workspace_dir: Arc, } fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { format!("{}_{}_{}", msg.channel, msg.sender, msg.id) } +fn conversation_history_key(msg: &traits::ChannelMessage) -> String { + format!("{}_{}", msg.channel, msg.sender) +} + fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { match channel_name { "telegram" => Some( @@ -96,6 +139,307 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { } } +fn supports_runtime_model_switch(channel_name: &str) -> bool { + matches!(channel_name, "telegram" | "discord") +} + +fn parse_runtime_command(channel_name: &str, content: &str) -> Option { + if !supports_runtime_model_switch(channel_name) { + return None; + } + + let trimmed = content.trim(); + if !trimmed.starts_with('/') { + return None; + } + + let mut parts = trimmed.split_whitespace(); + let command_token = parts.next()?; + let base_command = command_token + .split('@') + .next() + .unwrap_or(command_token) + .to_ascii_lowercase(); + + match base_command.as_str() { + "/models" => { + if let Some(provider) = parts.next() { + Some(ChannelRuntimeCommand::SetProvider( + provider.trim().to_string(), + )) + } else { + Some(ChannelRuntimeCommand::ShowProviders) + } + } + "/model" => { + let model = parts.collect::>().join(" ").trim().to_string(); + if model.is_empty() { + Some(ChannelRuntimeCommand::ShowModel) + } else { + Some(ChannelRuntimeCommand::SetModel(model)) + } + } + _ => None, + } +} + +fn resolve_provider_alias(name: &str) -> Option { + let candidate = name.trim(); + if candidate.is_empty() { + return None; + } + + let providers_list = providers::list_providers(); + for provider in providers_list { + if provider.name.eq_ignore_ascii_case(candidate) + || provider + .aliases + .iter() + .any(|alias| alias.eq_ignore_ascii_case(candidate)) + { + return Some(provider.name.to_string()); + } + } + + None +} + +fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection { + ChannelRouteSelection { + provider: ctx.default_provider.as_str().to_string(), + model: ctx.model.as_str().to_string(), + } +} + +fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> ChannelRouteSelection { + ctx.route_overrides + .lock() + .unwrap_or_else(|e| e.into_inner()) + .get(sender_key) + .cloned() + .unwrap_or_else(|| default_route_selection(ctx)) +} + +fn set_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str, next: ChannelRouteSelection) { + let default_route = default_route_selection(ctx); + let mut routes = ctx + .route_overrides + .lock() + .unwrap_or_else(|e| e.into_inner()); + if next == default_route { + routes.remove(sender_key); + } else { + routes.insert(sender_key.to_string(), next); + } +} + +fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) { + ctx.conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()) + .remove(sender_key); +} + +fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec { + let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE); + let Ok(raw) = std::fs::read_to_string(cache_path) else { + return Vec::new(); + }; + let Ok(state) = serde_json::from_str::(&raw) else { + return Vec::new(); + }; + + state + .entries + .into_iter() + .find(|entry| entry.provider == provider_name) + .map(|entry| { + entry + .models + .into_iter() + .take(MODEL_CACHE_PREVIEW_LIMIT) + .collect::>() + }) + .unwrap_or_default() +} + +async fn get_or_create_provider( + ctx: &ChannelRuntimeContext, + provider_name: &str, +) -> anyhow::Result> { + if provider_name == ctx.default_provider.as_str() { + return Ok(Arc::clone(&ctx.provider)); + } + + if let Some(existing) = ctx + .provider_cache + .lock() + .unwrap_or_else(|e| e.into_inner()) + .get(provider_name) + .cloned() + { + return Ok(existing); + } + + let api_url = if provider_name == ctx.default_provider.as_str() { + ctx.api_url.as_deref() + } else { + None + }; + + let provider = providers::create_resilient_provider_with_options( + provider_name, + ctx.api_key.as_deref(), + api_url, + &ctx.reliability, + &ctx.provider_runtime_options, + )?; + let provider: Arc = Arc::from(provider); + + if let Err(err) = provider.warmup().await { + tracing::warn!(provider = provider_name, "Provider warmup failed: {err}"); + } + + let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner()); + let cached = cache + .entry(provider_name.to_string()) + .or_insert_with(|| Arc::clone(&provider)); + Ok(Arc::clone(cached)) +} + +fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String { + let mut response = String::new(); + let _ = writeln!( + response, + "Current provider: `{}`\nCurrent model: `{}`", + current.provider, current.model + ); + response.push_str("\nSwitch model with `/model `.\n"); + + let cached_models = load_cached_model_preview(workspace_dir, ¤t.provider); + if cached_models.is_empty() { + let _ = writeln!( + response, + "\nNo cached model list found for `{}`. Ask the operator to run `zeroclaw models refresh --provider {}`.", + current.provider, current.provider + ); + } else { + let _ = writeln!( + response, + "\nCached model IDs (top {}):", + cached_models.len() + ); + for model in cached_models { + let _ = writeln!(response, "- `{model}`"); + } + } + + response +} + +fn build_providers_help_response(current: &ChannelRouteSelection) -> String { + let mut response = String::new(); + let _ = writeln!( + response, + "Current provider: `{}`\nCurrent model: `{}`", + current.provider, current.model + ); + response.push_str("\nSwitch provider with `/models `.\n"); + response.push_str("Switch model with `/model `.\n\n"); + response.push_str("Available providers:\n"); + for provider in providers::list_providers() { + if provider.aliases.is_empty() { + let _ = writeln!(response, "- {}", provider.name); + } else { + let _ = writeln!( + response, + "- {} (aliases: {})", + provider.name, + provider.aliases.join(", ") + ); + } + } + response +} + +async fn handle_runtime_command_if_needed( + ctx: &ChannelRuntimeContext, + msg: &traits::ChannelMessage, + target_channel: Option<&Arc>, +) -> bool { + let Some(command) = parse_runtime_command(&msg.channel, &msg.content) else { + return false; + }; + + let Some(channel) = target_channel else { + return true; + }; + + let sender_key = conversation_history_key(msg); + let mut current = get_route_selection(ctx, &sender_key); + + let response = match command { + ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t), + ChannelRuntimeCommand::SetProvider(raw_provider) => { + match resolve_provider_alias(&raw_provider) { + Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await { + Ok(_) => { + if provider_name != current.provider { + current.provider = provider_name.clone(); + set_route_selection(ctx, &sender_key, current.clone()); + clear_sender_history(ctx, &sender_key); + } + + format!( + "Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model ` to set a provider-compatible model.", + current.model + ) + } + Err(err) => { + let safe_err = providers::sanitize_api_error(&err.to_string()); + format!( + "Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}" + ) + } + }, + None => format!( + "Unknown provider `{raw_provider}`. Use `/models` to list valid providers." + ), + } + } + ChannelRuntimeCommand::ShowModel => { + build_models_help_response(¤t, ctx.workspace_dir.as_path()) + } + ChannelRuntimeCommand::SetModel(raw_model) => { + let model = raw_model.trim().trim_matches('`').to_string(); + if model.is_empty() { + "Model ID cannot be empty. Use `/model `.".to_string() + } else { + current.model = model.clone(); + set_route_selection(ctx, &sender_key, current.clone()); + clear_sender_history(ctx, &sender_key); + + format!( + "Model switched to `{model}` for provider `{}` in this sender session.", + current.provider + ) + } + } + }; + + if let Err(err) = channel + .send(&SendMessage::new(response, &msg.reply_target)) + .await + { + tracing::warn!( + "Failed to send runtime command response on {}: {err}", + channel.name() + ); + } + + true +} + async fn build_memory_context( mem: &dyn Memory, user_msg: &str, @@ -217,6 +561,30 @@ async fn process_channel_message(ctx: Arc, msg: traits::C truncate_with_ellipsis(&msg.content, 80) ); + let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); + if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await { + return; + } + + let history_key = conversation_history_key(&msg); + let route = get_route_selection(ctx.as_ref(), &history_key); + let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await { + Ok(provider) => provider, + Err(err) => { + let safe_err = providers::sanitize_api_error(&err.to_string()); + let message = format!( + "⚠️ Failed to initialize provider `{}`. Please run `/models` to choose another provider.\nDetails: {safe_err}", + route.provider + ); + if let Some(channel) = target_channel.as_ref() { + let _ = channel + .send(&SendMessage::new(message, &msg.reply_target)) + .await; + } + return; + } + }; + let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await; @@ -239,13 +607,10 @@ async fn process_channel_message(ctx: Arc, msg: traits::C format!("{memory_context}{}", msg.content) }; - let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); - println!(" ⏳ Processing message..."); let started_at = Instant::now(); // Build history from per-sender conversation cache - let history_key = format!("{}_{}", msg.channel, msg.sender); let mut prior_turns = ctx .conversation_histories .lock() @@ -333,12 +698,12 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let llm_result = tokio::time::timeout( Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), run_tool_call_loop( - ctx.provider.as_ref(), + active_provider.as_ref(), &mut history, ctx.tools_registry.as_ref(), ctx.observer.as_ref(), - "channel-runtime", - ctx.model.as_str(), + route.provider.as_str(), + route.model.as_str(), ctx.temperature, true, None, @@ -1117,16 +1482,17 @@ pub async fn start_channels(config: Config) -> Result<()> { .default_provider .clone() .unwrap_or_else(|| "openrouter".into()); + let provider_runtime_options = providers::ProviderRuntimeOptions { + auth_profile_override: None, + zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), + secrets_encrypt: config.secrets.encrypt, + }; let provider: Arc = Arc::from(providers::create_resilient_provider_with_options( &provider_name, config.api_key.as_deref(), config.api_url.as_deref(), &config.reliability, - &providers::ProviderRuntimeOptions { - auth_profile_override: None, - zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), - secrets_encrypt: config.secrets.encrypt, - }, + &provider_runtime_options, )?); // Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup) @@ -1441,9 +1807,13 @@ pub async fn start_channels(config: Config) -> Result<()> { println!(" 🚦 In-flight message limit: {max_in_flight_messages}"); + let mut provider_cache_seed: HashMap> = HashMap::new(); + provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider)); + let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name, provider: Arc::clone(&provider), + default_provider: Arc::new(provider_name), memory: Arc::clone(&mem), tools_registry: Arc::clone(&tools_registry), observer, @@ -1454,6 +1824,13 @@ pub async fn start_channels(config: Config) -> Result<()> { max_tool_iterations: config.agent.max_tool_iterations, min_relevance_score: config.memory.min_relevance_score, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(provider_cache_seed)), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: config.api_key.clone(), + api_url: config.api_url.clone(), + reliability: Arc::new(config.reliability.clone()), + provider_runtime_options, + workspace_dir: Arc::new(config.workspace_dir.clone()), }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -1506,6 +1883,41 @@ mod tests { stop_typing_calls: AtomicUsize, } + #[derive(Default)] + struct TelegramRecordingChannel { + sent_messages: tokio::sync::Mutex>, + } + + #[async_trait::async_trait] + impl Channel for TelegramRecordingChannel { + fn name(&self) -> &str { + "telegram" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + self.sent_messages + .lock() + .await + .push(format!("{}:{}", message.recipient, message.content)); + Ok(()) + } + + async fn listen( + &self, + _tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } + + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } + } + #[async_trait::async_trait] impl Channel for RecordingChannel { fn name(&self) -> &str { @@ -1667,6 +2079,39 @@ mod tests { struct MockPriceTool; + #[derive(Default)] + struct ModelCaptureProvider { + call_count: AtomicUsize, + models: std::sync::Mutex>, + } + + #[async_trait::async_trait] + impl Provider for ModelCaptureProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("fallback".to_string()) + } + + async fn chat_with_history( + &self, + _messages: &[ChatMessage], + model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.call_count.fetch_add(1, Ordering::SeqCst); + self.models + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(model.to_string()); + Ok("ok".to_string()) + } + } + #[async_trait::async_trait] impl Tool for MockPriceTool { fn name(&self) -> &str { @@ -1716,6 +2161,7 @@ mod tests { let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name: Arc::new(channels_by_name), provider: Arc::new(ToolCallingProvider), + default_provider: Arc::new("test-provider".to_string()), memory: Arc::new(NoopMemory), tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), observer: Arc::new(NoopObserver), @@ -1726,6 +2172,13 @@ mod tests { max_tool_iterations: 10, min_relevance_score: 0.0, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), }); process_channel_message( @@ -1760,6 +2213,7 @@ mod tests { let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name: Arc::new(channels_by_name), provider: Arc::new(ToolCallingAliasProvider), + default_provider: Arc::new("test-provider".to_string()), memory: Arc::new(NoopMemory), tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), observer: Arc::new(NoopObserver), @@ -1770,6 +2224,13 @@ mod tests { max_tool_iterations: 10, min_relevance_score: 0.0, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), }); process_channel_message( @@ -1793,6 +2254,153 @@ mod tests { assert!(!sent_messages[0].contains("mock_price")); } + #[tokio::test] + async fn process_channel_message_handles_models_command_without_llm_call() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let default_provider_impl = Arc::new(ModelCaptureProvider::default()); + let default_provider: Arc = default_provider_impl.clone(); + let fallback_provider_impl = Arc::new(ModelCaptureProvider::default()); + let fallback_provider: Arc = fallback_provider_impl.clone(); + + let mut provider_cache_seed: HashMap> = HashMap::new(); + provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider)); + provider_cache_seed.insert("openrouter".to_string(), fallback_provider); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::clone(&default_provider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("default-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(provider_cache_seed)), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + }); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "msg-cmd-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "/models openrouter".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + }, + ) + .await; + + let sent = channel_impl.sent_messages.lock().await; + assert_eq!(sent.len(), 1); + assert!(sent[0].contains("Provider switched to `openrouter`")); + + let route_key = "telegram_alice"; + let route = runtime_ctx + .route_overrides + .lock() + .unwrap_or_else(|e| e.into_inner()) + .get(route_key) + .cloned() + .expect("route should be stored for sender"); + assert_eq!(route.provider, "openrouter"); + assert_eq!(route.model, "default-model"); + + assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0); + assert_eq!(fallback_provider_impl.call_count.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn process_channel_message_uses_route_override_provider_and_model() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let default_provider_impl = Arc::new(ModelCaptureProvider::default()); + let default_provider: Arc = default_provider_impl.clone(); + let routed_provider_impl = Arc::new(ModelCaptureProvider::default()); + let routed_provider: Arc = routed_provider_impl.clone(); + + let mut provider_cache_seed: HashMap> = HashMap::new(); + provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider)); + provider_cache_seed.insert("openrouter".to_string(), routed_provider); + + let route_key = "telegram_alice".to_string(); + let mut route_overrides = HashMap::new(); + route_overrides.insert( + route_key, + ChannelRouteSelection { + provider: "openrouter".to_string(), + model: "route-model".to_string(), + }, + ); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::clone(&default_provider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("default-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(provider_cache_seed)), + route_overrides: Arc::new(Mutex::new(route_overrides)), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-routed-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "hello routed provider".to_string(), + channel: "telegram".to_string(), + timestamp: 2, + }, + ) + .await; + + assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0); + assert_eq!(routed_provider_impl.call_count.load(Ordering::SeqCst), 1); + assert_eq!( + routed_provider_impl + .models + .lock() + .unwrap_or_else(|e| e.into_inner()) + .as_slice(), + &["route-model".to_string()] + ); + } + struct NoopMemory; #[async_trait::async_trait] @@ -1858,6 +2466,7 @@ mod tests { provider: Arc::new(SlowProvider { delay: Duration::from_millis(250), }), + default_provider: Arc::new("test-provider".to_string()), memory: Arc::new(NoopMemory), tools_registry: Arc::new(vec![]), observer: Arc::new(NoopObserver), @@ -1868,6 +2477,13 @@ mod tests { max_tool_iterations: 10, min_relevance_score: 0.0, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -1920,6 +2536,7 @@ mod tests { provider: Arc::new(SlowProvider { delay: Duration::from_millis(20), }), + default_provider: Arc::new("test-provider".to_string()), memory: Arc::new(NoopMemory), tools_registry: Arc::new(vec![]), observer: Arc::new(NoopObserver), @@ -1930,6 +2547,13 @@ mod tests { max_tool_iterations: 10, min_relevance_score: 0.0, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), }); process_channel_message( @@ -2302,6 +2926,7 @@ mod tests { let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name: Arc::new(channels_by_name), provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), memory: Arc::new(NoopMemory), tools_registry: Arc::new(vec![]), observer: Arc::new(NoopObserver), @@ -2312,6 +2937,13 @@ mod tests { max_tool_iterations: 5, min_relevance_score: 0.0, conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), }); process_channel_message(