feat(channels): add runtime provider/model switching for telegram and discord

This commit is contained in:
Chummy 2026-02-18 22:00:15 +08:00
parent 0b66ed026c
commit 8988a069a6
3 changed files with 666 additions and 13 deletions

View file

@ -13,6 +13,21 @@ zeroclaw channel doctor
zeroclaw channel bind-telegram <IDENTITY> zeroclaw channel bind-telegram <IDENTITY>
``` ```
## In-Chat Runtime Model Switching (Telegram / Discord)
When running `zeroclaw channel start` (or daemon mode), Telegram and Discord now support sender-scoped runtime switching:
- `/models` — show available providers and current selection
- `/models <provider>` — switch provider for the current sender session
- `/model` — show current model and cached model IDs (if available)
- `/model <model-id>` — switch model for the current sender session
Notes:
- Switching clears only that sender's in-memory conversation history to avoid cross-model context contamination.
- Model cache previews come from `zeroclaw models refresh --provider <provider>`.
- These are runtime chat commands, not CLI subcommands.
## Channel Matrix ## Channel Matrix
| Channel | Config section | Access control field | Setup path | | Channel | Config section | Access control field | Setup path |

View file

@ -80,6 +80,13 @@ Last verified: **February 18, 2026**.
- `zeroclaw channel add <type> <json>` - `zeroclaw channel add <type> <json>`
- `zeroclaw channel remove <name>` - `zeroclaw channel remove <name>`
Runtime in-chat commands (Telegram/Discord while channel server is running):
- `/models`
- `/models <provider>`
- `/model`
- `/model <model-id>`
`add/remove` currently route you back to managed setup/manual config paths (not full declarative mutators yet). `add/remove` currently route you back to managed setup/manual config paths (not full declarative mutators yet).
### `integrations` ### `integrations`
@ -118,4 +125,3 @@ To verify docs against your current binary quickly:
zeroclaw --help zeroclaw --help
zeroclaw <command> --help zeroclaw <command> --help
``` ```

View file

@ -41,9 +41,10 @@ use crate::security::SecurityPolicy;
use crate::tools::{self, Tool}; use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis; use crate::util::truncate_with_ellipsis;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Write; use std::fmt::Write;
use std::path::PathBuf; use std::path::{Path, PathBuf};
use std::process::Command; use std::process::Command;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -66,11 +67,42 @@ const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4; const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
/// File name of the on-disk model cache kept under `<workspace>/state/`.
const MODEL_CACHE_FILE: &str = "models_cache.json";
/// Maximum number of cached model IDs shown in a `/model` response preview.
const MODEL_CACHE_PREVIEW_LIMIT: usize = 10;

/// Lazily-constructed provider instances, keyed by canonical provider name.
type ProviderCacheMap = Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>;
/// Per-sender route overrides, keyed by the "<channel>_<sender>" history key.
type RouteSelectionMap = Arc<Mutex<HashMap<String, ChannelRouteSelection>>>;

/// A sender session's active provider/model pair.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ChannelRouteSelection {
    provider: String,
    model: String,
}

/// A parsed in-chat runtime command (`/models` / `/model`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum ChannelRuntimeCommand {
    /// `/models` with no argument — list providers and the current selection.
    ShowProviders,
    /// `/models <provider>` — switch provider for the current sender session.
    SetProvider(String),
    /// `/model` with no argument — show current model and cached model IDs.
    ShowModel,
    /// `/model <model-id>` — switch model for the current sender session.
    SetModel(String),
}

/// Deserialized shape of `models_cache.json`
/// (presumably written by `zeroclaw models refresh` — see the operator hint
/// emitted in `build_models_help_response`).
#[derive(Debug, Clone, Default, Deserialize)]
struct ModelCacheState {
    entries: Vec<ModelCacheEntry>,
}

/// One provider's cached model-ID list inside `ModelCacheState`.
#[derive(Debug, Clone, Default, Deserialize)]
struct ModelCacheEntry {
    provider: String,
    models: Vec<String>,
}
#[derive(Clone)] #[derive(Clone)]
struct ChannelRuntimeContext { struct ChannelRuntimeContext {
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>, channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
provider: Arc<dyn Provider>, provider: Arc<dyn Provider>,
default_provider: Arc<String>,
memory: Arc<dyn Memory>, memory: Arc<dyn Memory>,
tools_registry: Arc<Vec<Box<dyn Tool>>>, tools_registry: Arc<Vec<Box<dyn Tool>>>,
observer: Arc<dyn Observer>, observer: Arc<dyn Observer>,
@ -81,12 +113,23 @@ struct ChannelRuntimeContext {
max_tool_iterations: usize, max_tool_iterations: usize,
min_relevance_score: f64, min_relevance_score: f64,
conversation_histories: ConversationHistoryMap, conversation_histories: ConversationHistoryMap,
provider_cache: ProviderCacheMap,
route_overrides: RouteSelectionMap,
api_key: Option<String>,
api_url: Option<String>,
reliability: Arc<crate::config::ReliabilityConfig>,
provider_runtime_options: providers::ProviderRuntimeOptions,
workspace_dir: Arc<PathBuf>,
} }
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
format!("{}_{}_{}", msg.channel, msg.sender, msg.id) format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
} }
/// Sender-scoped key ("<channel>_<sender>") shared by the conversation-history
/// cache and the route-override map.
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
    format!("{}_{}", msg.channel, msg.sender)
}
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
match channel_name { match channel_name {
"telegram" => Some( "telegram" => Some(
@ -96,6 +139,307 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
} }
} }
/// Whether in-chat `/models` / `/model` runtime commands are honored on the
/// named channel. Only Telegram and Discord opt in; the match is exact and
/// case-sensitive.
fn supports_runtime_model_switch(channel_name: &str) -> bool {
    channel_name == "telegram" || channel_name == "discord"
}
/// Parse an in-chat runtime command (`/models`, `/model`) from a message body.
///
/// Returns `None` when the channel does not support runtime switching, when
/// the text is not a slash command, or when the command is unrecognized.
/// A Telegram-style `@botname` suffix on the command token is stripped, and
/// the command is matched case-insensitively.
fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
    if !supports_runtime_model_switch(channel_name) {
        return None;
    }
    let text = content.trim();
    if !text.starts_with('/') {
        return None;
    }
    let mut tokens = text.split_whitespace();
    let first = tokens.next()?;
    // `/model@MyBot` -> `/model`
    let command = first.split('@').next().unwrap_or(first).to_ascii_lowercase();
    match command.as_str() {
        "/models" => Some(match tokens.next() {
            Some(provider) => ChannelRuntimeCommand::SetProvider(provider.trim().to_string()),
            None => ChannelRuntimeCommand::ShowProviders,
        }),
        "/model" => {
            // Re-join the remaining tokens; a bare `/model` shows the current route.
            let rest = tokens.collect::<Vec<_>>().join(" ");
            let model = rest.trim();
            if model.is_empty() {
                Some(ChannelRuntimeCommand::ShowModel)
            } else {
                Some(ChannelRuntimeCommand::SetModel(model.to_string()))
            }
        }
        _ => None,
    }
}
/// Resolve a user-typed provider name or alias to its canonical provider name.
///
/// Matching is case-insensitive against both the canonical name and every
/// registered alias; blank or unknown input yields `None`.
fn resolve_provider_alias(name: &str) -> Option<String> {
    let wanted = name.trim();
    if wanted.is_empty() {
        return None;
    }
    providers::list_providers()
        .into_iter()
        .find(|provider| {
            provider.name.eq_ignore_ascii_case(wanted)
                || provider
                    .aliases
                    .iter()
                    .any(|alias| alias.eq_ignore_ascii_case(wanted))
        })
        .map(|provider| provider.name.to_string())
}
/// The process-wide default provider/model pair, used for any sender without
/// an explicit route override.
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
    ChannelRouteSelection {
        provider: ctx.default_provider.to_string(),
        model: ctx.model.to_string(),
    }
}
/// Look up the sender's current provider/model route, falling back to the
/// process default when no override is stored. A poisoned lock is recovered
/// via `into_inner` rather than panicking.
fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> ChannelRouteSelection {
    let routes = ctx.route_overrides.lock().unwrap_or_else(|e| e.into_inner());
    match routes.get(sender_key) {
        Some(route) => route.clone(),
        None => default_route_selection(ctx),
    }
}
/// Store a sender's route override. Selecting the default route removes the
/// entry instead, so the override map only ever holds non-default selections.
fn set_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str, next: ChannelRouteSelection) {
    // Compute the default before taking the lock to keep the critical section short.
    let is_default = next == default_route_selection(ctx);
    let mut routes = ctx.route_overrides.lock().unwrap_or_else(|e| e.into_inner());
    if is_default {
        routes.remove(sender_key);
    } else {
        routes.insert(sender_key.to_string(), next);
    }
}
/// Drop the sender's in-memory conversation history. Called after a provider
/// or model switch to avoid cross-model context contamination.
fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) {
    let mut histories = ctx
        .conversation_histories
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    histories.remove(sender_key);
}
/// Read the on-disk model cache (`<workspace>/state/models_cache.json`) and
/// return up to `MODEL_CACHE_PREVIEW_LIMIT` model IDs for `provider_name`.
///
/// Any read or parse failure, or an unknown provider, yields an empty preview
/// — the cache is purely advisory.
fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<String> {
    let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE);
    let state = std::fs::read_to_string(cache_path)
        .ok()
        .and_then(|raw| serde_json::from_str::<ModelCacheState>(&raw).ok());
    let Some(state) = state else {
        return Vec::new();
    };
    for entry in state.entries {
        if entry.provider == provider_name {
            return entry
                .models
                .into_iter()
                .take(MODEL_CACHE_PREVIEW_LIMIT)
                .collect();
        }
    }
    Vec::new()
}
/// Return a provider instance for `provider_name`, creating (and warming up)
/// one on first use.
///
/// The process-default provider is always served straight from `ctx.provider`.
/// Other providers are looked up in `ctx.provider_cache`; on a miss a new
/// resilient provider is built, warmed up best-effort, and inserted.
///
/// The cache lock is never held across an `.await`; if two tasks race to
/// create the same provider, `entry().or_insert_with` keeps whichever was
/// inserted first and both callers share it.
///
/// # Errors
/// Propagates provider-construction failures; warmup failures are only logged.
async fn get_or_create_provider(
    ctx: &ChannelRuntimeContext,
    provider_name: &str,
) -> anyhow::Result<Arc<dyn Provider>> {
    if provider_name == ctx.default_provider.as_str() {
        return Ok(Arc::clone(&ctx.provider));
    }
    if let Some(existing) = ctx
        .provider_cache
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .get(provider_name)
        .cloned()
    {
        return Ok(existing);
    }
    // A custom `api_url` only applies to the default provider, and that case
    // already returned above — non-default providers always use their
    // built-in endpoint. (The previous conditional here was dead code: its
    // condition could never be true past the early return.)
    let provider = providers::create_resilient_provider_with_options(
        provider_name,
        ctx.api_key.as_deref(),
        None,
        &ctx.reliability,
        &ctx.provider_runtime_options,
    )?;
    let provider: Arc<dyn Provider> = Arc::from(provider);
    // Best-effort warmup (TLS handshake, DNS); failure is non-fatal.
    if let Err(err) = provider.warmup().await {
        tracing::warn!(provider = provider_name, "Provider warmup failed: {err}");
    }
    let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
    let cached = cache
        .entry(provider_name.to_string())
        .or_insert_with(|| Arc::clone(&provider));
    Ok(Arc::clone(cached))
}
/// Render the `/model` response: the sender's current provider/model pair plus
/// a preview of cached model IDs for that provider (or an operator hint when
/// no cache exists).
fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String {
    let mut out = String::new();
    let _ = writeln!(
        out,
        "Current provider: `{}`\nCurrent model: `{}`",
        current.provider, current.model
    );
    out.push_str("\nSwitch model with `/model <model-id>`.\n");
    let preview = load_cached_model_preview(workspace_dir, &current.provider);
    if preview.is_empty() {
        let _ = writeln!(
            out,
            "\nNo cached model list found for `{}`. Ask the operator to run `zeroclaw models refresh --provider {}`.",
            current.provider, current.provider
        );
        return out;
    }
    let _ = writeln!(out, "\nCached model IDs (top {}):", preview.len());
    for model in preview {
        let _ = writeln!(out, "- `{model}`");
    }
    out
}
/// Render the `/models` response: the sender's current provider/model pair,
/// usage hints, and every registered provider with its aliases.
fn build_providers_help_response(current: &ChannelRouteSelection) -> String {
    let mut out = String::new();
    let _ = writeln!(
        out,
        "Current provider: `{}`\nCurrent model: `{}`",
        current.provider, current.model
    );
    out.push_str("\nSwitch provider with `/models <provider>`.\n");
    out.push_str("Switch model with `/model <model-id>`.\n\n");
    out.push_str("Available providers:\n");
    for entry in providers::list_providers() {
        if entry.aliases.is_empty() {
            let _ = writeln!(out, "- {}", entry.name);
        } else {
            let _ = writeln!(
                out,
                "- {} (aliases: {})",
                entry.name,
                entry.aliases.join(", ")
            );
        }
    }
    out
}
/// Intercept `/models` / `/model` chat commands before normal LLM processing.
///
/// Returns `true` when the message was a runtime command (whether or not a
/// response could be delivered), so the caller skips the LLM round-trip;
/// `false` means the message should flow through normal processing.
async fn handle_runtime_command_if_needed(
    ctx: &ChannelRuntimeContext,
    msg: &traits::ChannelMessage,
    target_channel: Option<&Arc<dyn Channel>>,
) -> bool {
    let Some(command) = parse_runtime_command(&msg.channel, &msg.content) else {
        return false;
    };
    // Command recognized but no channel to answer on: swallow the message anyway.
    let Some(channel) = target_channel else {
        return true;
    };
    let sender_key = conversation_history_key(msg);
    let mut current = get_route_selection(ctx, &sender_key);
    let response = match command {
        ChannelRuntimeCommand::ShowProviders => build_providers_help_response(&current),
        ChannelRuntimeCommand::SetProvider(raw_provider) => {
            match resolve_provider_alias(&raw_provider) {
                Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
                    Ok(_) => {
                        // Only mutate state (and wipe history) on an actual change;
                        // re-selecting the current provider is a no-op.
                        if provider_name != current.provider {
                            current.provider = provider_name.clone();
                            set_route_selection(ctx, &sender_key, current.clone());
                            clear_sender_history(ctx, &sender_key);
                        }
                        format!(
                            "Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
                            current.model
                        )
                    }
                    Err(err) => {
                        // Sanitize before echoing into chat — raw errors may carry
                        // endpoint/credential details.
                        let safe_err = providers::sanitize_api_error(&err.to_string());
                        format!(
                            "Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
                        )
                    }
                },
                None => format!(
                    "Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
                ),
            }
        }
        ChannelRuntimeCommand::ShowModel => {
            build_models_help_response(&current, ctx.workspace_dir.as_path())
        }
        ChannelRuntimeCommand::SetModel(raw_model) => {
            // Tolerate model IDs pasted with surrounding backticks.
            let model = raw_model.trim().trim_matches('`').to_string();
            if model.is_empty() {
                "Model ID cannot be empty. Use `/model <model-id>`.".to_string()
            } else {
                current.model = model.clone();
                set_route_selection(ctx, &sender_key, current.clone());
                clear_sender_history(ctx, &sender_key);
                format!(
                    "Model switched to `{model}` for provider `{}` in this sender session.",
                    current.provider
                )
            }
        }
    };
    // Delivery failure is logged, not retried; the command still counts as handled.
    if let Err(err) = channel
        .send(&SendMessage::new(response, &msg.reply_target))
        .await
    {
        tracing::warn!(
            "Failed to send runtime command response on {}: {err}",
            channel.name()
        );
    }
    true
}
async fn build_memory_context( async fn build_memory_context(
mem: &dyn Memory, mem: &dyn Memory,
user_msg: &str, user_msg: &str,
@ -217,6 +561,30 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
truncate_with_ellipsis(&msg.content, 80) truncate_with_ellipsis(&msg.content, 80)
); );
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
return;
}
let history_key = conversation_history_key(&msg);
let route = get_route_selection(ctx.as_ref(), &history_key);
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
Ok(provider) => provider,
Err(err) => {
let safe_err = providers::sanitize_api_error(&err.to_string());
let message = format!(
"⚠️ Failed to initialize provider `{}`. Please run `/models` to choose another provider.\nDetails: {safe_err}",
route.provider
);
if let Some(channel) = target_channel.as_ref() {
let _ = channel
.send(&SendMessage::new(message, &msg.reply_target))
.await;
}
return;
}
};
let memory_context = let memory_context =
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await; build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
@ -239,13 +607,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
format!("{memory_context}{}", msg.content) format!("{memory_context}{}", msg.content)
}; };
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
println!(" ⏳ Processing message..."); println!(" ⏳ Processing message...");
let started_at = Instant::now(); let started_at = Instant::now();
// Build history from per-sender conversation cache // Build history from per-sender conversation cache
let history_key = format!("{}_{}", msg.channel, msg.sender);
let mut prior_turns = ctx let mut prior_turns = ctx
.conversation_histories .conversation_histories
.lock() .lock()
@ -333,12 +698,12 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let llm_result = tokio::time::timeout( let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
run_tool_call_loop( run_tool_call_loop(
ctx.provider.as_ref(), active_provider.as_ref(),
&mut history, &mut history,
ctx.tools_registry.as_ref(), ctx.tools_registry.as_ref(),
ctx.observer.as_ref(), ctx.observer.as_ref(),
"channel-runtime", route.provider.as_str(),
ctx.model.as_str(), route.model.as_str(),
ctx.temperature, ctx.temperature,
true, true,
None, None,
@ -1117,16 +1482,17 @@ pub async fn start_channels(config: Config) -> Result<()> {
.default_provider .default_provider
.clone() .clone()
.unwrap_or_else(|| "openrouter".into()); .unwrap_or_else(|| "openrouter".into());
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
};
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options( let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options(
&provider_name, &provider_name,
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(), config.api_url.as_deref(),
&config.reliability, &config.reliability,
&providers::ProviderRuntimeOptions { &provider_runtime_options,
auth_profile_override: None,
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
},
)?); )?);
// Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup) // Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup)
@ -1441,9 +1807,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
println!(" 🚦 In-flight message limit: {max_in_flight_messages}"); println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
let runtime_ctx = Arc::new(ChannelRuntimeContext { let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name, channels_by_name,
provider: Arc::clone(&provider), provider: Arc::clone(&provider),
default_provider: Arc::new(provider_name),
memory: Arc::clone(&mem), memory: Arc::clone(&mem),
tools_registry: Arc::clone(&tools_registry), tools_registry: Arc::clone(&tools_registry),
observer, observer,
@ -1454,6 +1824,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
max_tool_iterations: config.agent.max_tool_iterations, max_tool_iterations: config.agent.max_tool_iterations,
min_relevance_score: config.memory.min_relevance_score, min_relevance_score: config.memory.min_relevance_score,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: config.api_key.clone(),
api_url: config.api_url.clone(),
reliability: Arc::new(config.reliability.clone()),
provider_runtime_options,
workspace_dir: Arc::new(config.workspace_dir.clone()),
}); });
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
@ -1506,6 +1883,41 @@ mod tests {
stop_typing_calls: AtomicUsize, stop_typing_calls: AtomicUsize,
} }
/// Test double: a `Channel` that reports its name as "telegram" and records
/// every outbound message as "<recipient>:<content>" instead of sending it.
#[derive(Default)]
struct TelegramRecordingChannel {
    sent_messages: tokio::sync::Mutex<Vec<String>>,
}

#[async_trait::async_trait]
impl Channel for TelegramRecordingChannel {
    fn name(&self) -> &str {
        "telegram"
    }

    async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
        self.sent_messages
            .lock()
            .await
            .push(format!("{}:{}", message.recipient, message.content));
        Ok(())
    }

    // No-op listener: tests inject messages directly into the dispatch path.
    async fn listen(
        &self,
        _tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
        Ok(())
    }

    async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
        Ok(())
    }
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl Channel for RecordingChannel { impl Channel for RecordingChannel {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -1667,6 +2079,39 @@ mod tests {
struct MockPriceTool; struct MockPriceTool;
/// Test double: a `Provider` that counts `chat_with_history` calls and records
/// which model ID each call requested, so tests can assert routing decisions.
#[derive(Default)]
struct ModelCaptureProvider {
    call_count: AtomicUsize,
    models: std::sync::Mutex<Vec<String>>,
}

#[async_trait::async_trait]
impl Provider for ModelCaptureProvider {
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        Ok("fallback".to_string())
    }

    async fn chat_with_history(
        &self,
        _messages: &[ChatMessage],
        model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // Record the call and the requested model; recover a poisoned lock.
        self.call_count.fetch_add(1, Ordering::SeqCst);
        self.models
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .push(model.to_string());
        Ok("ok".to_string())
    }
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl Tool for MockPriceTool { impl Tool for MockPriceTool {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -1716,6 +2161,7 @@ mod tests {
let runtime_ctx = Arc::new(ChannelRuntimeContext { let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name), channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider), provider: Arc::new(ToolCallingProvider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory), memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
observer: Arc::new(NoopObserver), observer: Arc::new(NoopObserver),
@ -1726,6 +2172,13 @@ mod tests {
max_tool_iterations: 10, max_tool_iterations: 10,
min_relevance_score: 0.0, min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
}); });
process_channel_message( process_channel_message(
@ -1760,6 +2213,7 @@ mod tests {
let runtime_ctx = Arc::new(ChannelRuntimeContext { let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name), channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingAliasProvider), provider: Arc::new(ToolCallingAliasProvider),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory), memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
observer: Arc::new(NoopObserver), observer: Arc::new(NoopObserver),
@ -1770,6 +2224,13 @@ mod tests {
max_tool_iterations: 10, max_tool_iterations: 10,
min_relevance_score: 0.0, min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
}); });
process_channel_message( process_channel_message(
@ -1793,6 +2254,153 @@ mod tests {
assert!(!sent_messages[0].contains("mock_price")); assert!(!sent_messages[0].contains("mock_price"));
} }
#[tokio::test]
async fn process_channel_message_handles_models_command_without_llm_call() {
    // A `/models <provider>` command must be answered by the runtime command
    // handler directly: a route override is stored for the sender and neither
    // provider is ever asked for a completion.
    let channel_impl = Arc::new(TelegramRecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let mut channels_by_name = HashMap::new();
    channels_by_name.insert(channel.name().to_string(), channel);

    let default_provider_impl = Arc::new(ModelCaptureProvider::default());
    let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
    let fallback_provider_impl = Arc::new(ModelCaptureProvider::default());
    let fallback_provider: Arc<dyn Provider> = fallback_provider_impl.clone();

    // Pre-seed the provider cache so switching to "openrouter" needs no real
    // provider construction or network access.
    let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
    provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
    provider_cache_seed.insert("openrouter".to_string(), fallback_provider);

    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::clone(&default_provider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("default-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(HashMap::new())),
        provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
    });

    process_channel_message(
        runtime_ctx.clone(),
        traits::ChannelMessage {
            id: "msg-cmd-1".to_string(),
            sender: "alice".to_string(),
            reply_target: "chat-1".to_string(),
            content: "/models openrouter".to_string(),
            channel: "telegram".to_string(),
            timestamp: 1,
        },
    )
    .await;

    // Exactly one confirmation message was sent back to the chat.
    let sent = channel_impl.sent_messages.lock().await;
    assert_eq!(sent.len(), 1);
    assert!(sent[0].contains("Provider switched to `openrouter`"));

    // The override is keyed by "<channel>_<sender>" and keeps the default model.
    let route_key = "telegram_alice";
    let route = runtime_ctx
        .route_overrides
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .get(route_key)
        .cloned()
        .expect("route should be stored for sender");
    assert_eq!(route.provider, "openrouter");
    assert_eq!(route.model, "default-model");

    // No LLM call was made on either provider.
    assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
    assert_eq!(fallback_provider_impl.call_count.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn process_channel_message_uses_route_override_provider_and_model() {
    // With a pre-existing route override, an ordinary message must be served
    // by the overridden provider with the overridden model — the default
    // provider is never called.
    let channel_impl = Arc::new(TelegramRecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let mut channels_by_name = HashMap::new();
    channels_by_name.insert(channel.name().to_string(), channel);

    let default_provider_impl = Arc::new(ModelCaptureProvider::default());
    let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
    let routed_provider_impl = Arc::new(ModelCaptureProvider::default());
    let routed_provider: Arc<dyn Provider> = routed_provider_impl.clone();

    // Seed the cache so the routed provider resolves without construction.
    let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
    provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
    provider_cache_seed.insert("openrouter".to_string(), routed_provider);

    // Pre-install the sender's override: openrouter / route-model.
    let route_key = "telegram_alice".to_string();
    let mut route_overrides = HashMap::new();
    route_overrides.insert(
        route_key,
        ChannelRouteSelection {
            provider: "openrouter".to_string(),
            model: "route-model".to_string(),
        },
    );

    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::clone(&default_provider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("default-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(HashMap::new())),
        provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
        route_overrides: Arc::new(Mutex::new(route_overrides)),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
    });

    process_channel_message(
        runtime_ctx,
        traits::ChannelMessage {
            id: "msg-routed-1".to_string(),
            sender: "alice".to_string(),
            reply_target: "chat-1".to_string(),
            content: "hello routed provider".to_string(),
            channel: "telegram".to_string(),
            timestamp: 2,
        },
    )
    .await;

    // The routed provider handled the single LLM call with the override model.
    assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
    assert_eq!(routed_provider_impl.call_count.load(Ordering::SeqCst), 1);
    assert_eq!(
        routed_provider_impl
            .models
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .as_slice(),
        &["route-model".to_string()]
    );
}
struct NoopMemory; struct NoopMemory;
#[async_trait::async_trait] #[async_trait::async_trait]
@ -1858,6 +2466,7 @@ mod tests {
provider: Arc::new(SlowProvider { provider: Arc::new(SlowProvider {
delay: Duration::from_millis(250), delay: Duration::from_millis(250),
}), }),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory), memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]), tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver), observer: Arc::new(NoopObserver),
@ -1868,6 +2477,13 @@ mod tests {
max_tool_iterations: 10, max_tool_iterations: 10,
min_relevance_score: 0.0, min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
}); });
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4); let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@ -1920,6 +2536,7 @@ mod tests {
provider: Arc::new(SlowProvider { provider: Arc::new(SlowProvider {
delay: Duration::from_millis(20), delay: Duration::from_millis(20),
}), }),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory), memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]), tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver), observer: Arc::new(NoopObserver),
@ -1930,6 +2547,13 @@ mod tests {
max_tool_iterations: 10, max_tool_iterations: 10,
min_relevance_score: 0.0, min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
}); });
process_channel_message( process_channel_message(
@ -2302,6 +2926,7 @@ mod tests {
let runtime_ctx = Arc::new(ChannelRuntimeContext { let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name), channels_by_name: Arc::new(channels_by_name),
provider: provider_impl.clone(), provider: provider_impl.clone(),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(NoopMemory), memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]), tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver), observer: Arc::new(NoopObserver),
@ -2312,6 +2937,13 @@ mod tests {
max_tool_iterations: 5, max_tool_iterations: 5,
min_relevance_score: 0.0, min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())), conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
}); });
process_channel_message( process_channel_message(