feat(channels): add runtime provider/model switching for telegram and discord
This commit is contained in:
parent
0b66ed026c
commit
8988a069a6
3 changed files with 666 additions and 13 deletions
|
|
@ -13,6 +13,21 @@ zeroclaw channel doctor
|
|||
zeroclaw channel bind-telegram <IDENTITY>
|
||||
```
|
||||
|
||||
## In-Chat Runtime Model Switching (Telegram / Discord)
|
||||
|
||||
When running `zeroclaw channel start` (or daemon mode), Telegram and Discord now support sender-scoped runtime switching:
|
||||
|
||||
- `/models` — show available providers and current selection
|
||||
- `/models <provider>` — switch provider for the current sender session
|
||||
- `/model` — show current model and cached model IDs (if available)
|
||||
- `/model <model-id>` — switch model for the current sender session
|
||||
|
||||
Notes:
|
||||
|
||||
- Switching clears only that sender's in-memory conversation history to avoid cross-model context contamination.
|
||||
- Model cache previews come from `zeroclaw models refresh --provider <ID>`.
|
||||
- These are runtime chat commands, not CLI subcommands.
|
||||
|
||||
## Channel Matrix
|
||||
|
||||
| Channel | Config section | Access control field | Setup path |
|
||||
|
|
|
|||
|
|
@ -80,6 +80,13 @@ Last verified: **February 18, 2026**.
|
|||
- `zeroclaw channel add <type> <json>`
|
||||
- `zeroclaw channel remove <name>`
|
||||
|
||||
Runtime in-chat commands (Telegram/Discord while channel server is running):
|
||||
|
||||
- `/models`
|
||||
- `/models <provider>`
|
||||
- `/model`
|
||||
- `/model <model-id>`
|
||||
|
||||
`add/remove` currently route you back to managed setup/manual config paths (not full declarative mutators yet).
|
||||
|
||||
### `integrations`
|
||||
|
|
@ -118,4 +125,3 @@ To verify docs against your current binary quickly:
|
|||
zeroclaw --help
|
||||
zeroclaw <command> --help
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -41,9 +41,10 @@ use crate::security::SecurityPolicy;
|
|||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
|
@ -66,11 +67,42 @@ const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
|||
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
||||
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
||||
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
|
||||
const MODEL_CACHE_FILE: &str = "models_cache.json";
|
||||
const MODEL_CACHE_PREVIEW_LIMIT: usize = 10;
|
||||
|
||||
type ProviderCacheMap = Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>;
|
||||
type RouteSelectionMap = Arc<Mutex<HashMap<String, ChannelRouteSelection>>>;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct ChannelRouteSelection {
|
||||
provider: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum ChannelRuntimeCommand {
|
||||
ShowProviders,
|
||||
SetProvider(String),
|
||||
ShowModel,
|
||||
SetModel(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
struct ModelCacheState {
|
||||
entries: Vec<ModelCacheEntry>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
struct ModelCacheEntry {
|
||||
provider: String,
|
||||
models: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelRuntimeContext {
|
||||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
provider: Arc<dyn Provider>,
|
||||
default_provider: Arc<String>,
|
||||
memory: Arc<dyn Memory>,
|
||||
tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||||
observer: Arc<dyn Observer>,
|
||||
|
|
@ -81,12 +113,23 @@ struct ChannelRuntimeContext {
|
|||
max_tool_iterations: usize,
|
||||
min_relevance_score: f64,
|
||||
conversation_histories: ConversationHistoryMap,
|
||||
provider_cache: ProviderCacheMap,
|
||||
route_overrides: RouteSelectionMap,
|
||||
api_key: Option<String>,
|
||||
api_url: Option<String>,
|
||||
reliability: Arc<crate::config::ReliabilityConfig>,
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
workspace_dir: Arc<PathBuf>,
|
||||
}
|
||||
|
||||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("{}_{}", msg.channel, msg.sender)
|
||||
}
|
||||
|
||||
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||
match channel_name {
|
||||
"telegram" => Some(
|
||||
|
|
@ -96,6 +139,307 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
|||
}
|
||||
}
|
||||
|
||||
fn supports_runtime_model_switch(channel_name: &str) -> bool {
|
||||
matches!(channel_name, "telegram" | "discord")
|
||||
}
|
||||
|
||||
fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
|
||||
if !supports_runtime_model_switch(channel_name) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let trimmed = content.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut parts = trimmed.split_whitespace();
|
||||
let command_token = parts.next()?;
|
||||
let base_command = command_token
|
||||
.split('@')
|
||||
.next()
|
||||
.unwrap_or(command_token)
|
||||
.to_ascii_lowercase();
|
||||
|
||||
match base_command.as_str() {
|
||||
"/models" => {
|
||||
if let Some(provider) = parts.next() {
|
||||
Some(ChannelRuntimeCommand::SetProvider(
|
||||
provider.trim().to_string(),
|
||||
))
|
||||
} else {
|
||||
Some(ChannelRuntimeCommand::ShowProviders)
|
||||
}
|
||||
}
|
||||
"/model" => {
|
||||
let model = parts.collect::<Vec<_>>().join(" ").trim().to_string();
|
||||
if model.is_empty() {
|
||||
Some(ChannelRuntimeCommand::ShowModel)
|
||||
} else {
|
||||
Some(ChannelRuntimeCommand::SetModel(model))
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_provider_alias(name: &str) -> Option<String> {
|
||||
let candidate = name.trim();
|
||||
if candidate.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let providers_list = providers::list_providers();
|
||||
for provider in providers_list {
|
||||
if provider.name.eq_ignore_ascii_case(candidate)
|
||||
|| provider
|
||||
.aliases
|
||||
.iter()
|
||||
.any(|alias| alias.eq_ignore_ascii_case(candidate))
|
||||
{
|
||||
return Some(provider.name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
|
||||
ChannelRouteSelection {
|
||||
provider: ctx.default_provider.as_str().to_string(),
|
||||
model: ctx.model.as_str().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> ChannelRouteSelection {
|
||||
ctx.route_overrides
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.get(sender_key)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| default_route_selection(ctx))
|
||||
}
|
||||
|
||||
fn set_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str, next: ChannelRouteSelection) {
|
||||
let default_route = default_route_selection(ctx);
|
||||
let mut routes = ctx
|
||||
.route_overrides
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
if next == default_route {
|
||||
routes.remove(sender_key);
|
||||
} else {
|
||||
routes.insert(sender_key.to_string(), next);
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) {
|
||||
ctx.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.remove(sender_key);
|
||||
}
|
||||
|
||||
fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<String> {
|
||||
let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE);
|
||||
let Ok(raw) = std::fs::read_to_string(cache_path) else {
|
||||
return Vec::new();
|
||||
};
|
||||
let Ok(state) = serde_json::from_str::<ModelCacheState>(&raw) else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
state
|
||||
.entries
|
||||
.into_iter()
|
||||
.find(|entry| entry.provider == provider_name)
|
||||
.map(|entry| {
|
||||
entry
|
||||
.models
|
||||
.into_iter()
|
||||
.take(MODEL_CACHE_PREVIEW_LIMIT)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn get_or_create_provider(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
provider_name: &str,
|
||||
) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
if provider_name == ctx.default_provider.as_str() {
|
||||
return Ok(Arc::clone(&ctx.provider));
|
||||
}
|
||||
|
||||
if let Some(existing) = ctx
|
||||
.provider_cache
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.get(provider_name)
|
||||
.cloned()
|
||||
{
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let api_url = if provider_name == ctx.default_provider.as_str() {
|
||||
ctx.api_url.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let provider = providers::create_resilient_provider_with_options(
|
||||
provider_name,
|
||||
ctx.api_key.as_deref(),
|
||||
api_url,
|
||||
&ctx.reliability,
|
||||
&ctx.provider_runtime_options,
|
||||
)?;
|
||||
let provider: Arc<dyn Provider> = Arc::from(provider);
|
||||
|
||||
if let Err(err) = provider.warmup().await {
|
||||
tracing::warn!(provider = provider_name, "Provider warmup failed: {err}");
|
||||
}
|
||||
|
||||
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let cached = cache
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert_with(|| Arc::clone(&provider));
|
||||
Ok(Arc::clone(cached))
|
||||
}
|
||||
|
||||
fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String {
|
||||
let mut response = String::new();
|
||||
let _ = writeln!(
|
||||
response,
|
||||
"Current provider: `{}`\nCurrent model: `{}`",
|
||||
current.provider, current.model
|
||||
);
|
||||
response.push_str("\nSwitch model with `/model <model-id>`.\n");
|
||||
|
||||
let cached_models = load_cached_model_preview(workspace_dir, ¤t.provider);
|
||||
if cached_models.is_empty() {
|
||||
let _ = writeln!(
|
||||
response,
|
||||
"\nNo cached model list found for `{}`. Ask the operator to run `zeroclaw models refresh --provider {}`.",
|
||||
current.provider, current.provider
|
||||
);
|
||||
} else {
|
||||
let _ = writeln!(
|
||||
response,
|
||||
"\nCached model IDs (top {}):",
|
||||
cached_models.len()
|
||||
);
|
||||
for model in cached_models {
|
||||
let _ = writeln!(response, "- `{model}`");
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
fn build_providers_help_response(current: &ChannelRouteSelection) -> String {
|
||||
let mut response = String::new();
|
||||
let _ = writeln!(
|
||||
response,
|
||||
"Current provider: `{}`\nCurrent model: `{}`",
|
||||
current.provider, current.model
|
||||
);
|
||||
response.push_str("\nSwitch provider with `/models <provider>`.\n");
|
||||
response.push_str("Switch model with `/model <model-id>`.\n\n");
|
||||
response.push_str("Available providers:\n");
|
||||
for provider in providers::list_providers() {
|
||||
if provider.aliases.is_empty() {
|
||||
let _ = writeln!(response, "- {}", provider.name);
|
||||
} else {
|
||||
let _ = writeln!(
|
||||
response,
|
||||
"- {} (aliases: {})",
|
||||
provider.name,
|
||||
provider.aliases.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
response
|
||||
}
|
||||
|
||||
async fn handle_runtime_command_if_needed(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
msg: &traits::ChannelMessage,
|
||||
target_channel: Option<&Arc<dyn Channel>>,
|
||||
) -> bool {
|
||||
let Some(command) = parse_runtime_command(&msg.channel, &msg.content) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let Some(channel) = target_channel else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let sender_key = conversation_history_key(msg);
|
||||
let mut current = get_route_selection(ctx, &sender_key);
|
||||
|
||||
let response = match command {
|
||||
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t),
|
||||
ChannelRuntimeCommand::SetProvider(raw_provider) => {
|
||||
match resolve_provider_alias(&raw_provider) {
|
||||
Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
|
||||
Ok(_) => {
|
||||
if provider_name != current.provider {
|
||||
current.provider = provider_name.clone();
|
||||
set_route_selection(ctx, &sender_key, current.clone());
|
||||
clear_sender_history(ctx, &sender_key);
|
||||
}
|
||||
|
||||
format!(
|
||||
"Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
|
||||
current.model
|
||||
)
|
||||
}
|
||||
Err(err) => {
|
||||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||||
format!(
|
||||
"Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
|
||||
)
|
||||
}
|
||||
},
|
||||
None => format!(
|
||||
"Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
|
||||
),
|
||||
}
|
||||
}
|
||||
ChannelRuntimeCommand::ShowModel => {
|
||||
build_models_help_response(¤t, ctx.workspace_dir.as_path())
|
||||
}
|
||||
ChannelRuntimeCommand::SetModel(raw_model) => {
|
||||
let model = raw_model.trim().trim_matches('`').to_string();
|
||||
if model.is_empty() {
|
||||
"Model ID cannot be empty. Use `/model <model-id>`.".to_string()
|
||||
} else {
|
||||
current.model = model.clone();
|
||||
set_route_selection(ctx, &sender_key, current.clone());
|
||||
clear_sender_history(ctx, &sender_key);
|
||||
|
||||
format!(
|
||||
"Model switched to `{model}` for provider `{}` in this sender session.",
|
||||
current.provider
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = channel
|
||||
.send(&SendMessage::new(response, &msg.reply_target))
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
"Failed to send runtime command response on {}: {err}",
|
||||
channel.name()
|
||||
);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
async fn build_memory_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
|
|
@ -217,6 +561,30 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
truncate_with_ellipsis(&msg.content, 80)
|
||||
);
|
||||
|
||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
|
||||
return;
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
Err(err) => {
|
||||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||||
let message = format!(
|
||||
"⚠️ Failed to initialize provider `{}`. Please run `/models` to choose another provider.\nDetails: {safe_err}",
|
||||
route.provider
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
let _ = channel
|
||||
.send(&SendMessage::new(message, &msg.reply_target))
|
||||
.await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let memory_context =
|
||||
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
|
||||
|
||||
|
|
@ -239,13 +607,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
format!("{memory_context}{}", msg.content)
|
||||
};
|
||||
|
||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
|
||||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
// Build history from per-sender conversation cache
|
||||
let history_key = format!("{}_{}", msg.channel, msg.sender);
|
||||
let mut prior_turns = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
|
|
@ -333,12 +698,12 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
run_tool_call_loop(
|
||||
ctx.provider.as_ref(),
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
"channel-runtime",
|
||||
ctx.model.as_str(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
ctx.temperature,
|
||||
true,
|
||||
None,
|
||||
|
|
@ -1117,16 +1482,17 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
.default_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| "openrouter".into());
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
};
|
||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options(
|
||||
&provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
secrets_encrypt: config.secrets.encrypt,
|
||||
},
|
||||
&provider_runtime_options,
|
||||
)?);
|
||||
|
||||
// Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup)
|
||||
|
|
@ -1441,9 +1807,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
|
||||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name,
|
||||
provider: Arc::clone(&provider),
|
||||
default_provider: Arc::new(provider_name),
|
||||
memory: Arc::clone(&mem),
|
||||
tools_registry: Arc::clone(&tools_registry),
|
||||
observer,
|
||||
|
|
@ -1454,6 +1824,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
max_tool_iterations: config.agent.max_tool_iterations,
|
||||
min_relevance_score: config.memory.min_relevance_score,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: config.api_key.clone(),
|
||||
api_url: config.api_url.clone(),
|
||||
reliability: Arc::new(config.reliability.clone()),
|
||||
provider_runtime_options,
|
||||
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
||||
});
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||
|
|
@ -1506,6 +1883,41 @@ mod tests {
|
|||
stop_typing_calls: AtomicUsize,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TelegramRecordingChannel {
|
||||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Channel for TelegramRecordingChannel {
|
||||
fn name(&self) -> &str {
|
||||
"telegram"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
self.sent_messages
|
||||
.lock()
|
||||
.await
|
||||
.push(format!("{}:{}", message.recipient, message.content));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Channel for RecordingChannel {
|
||||
fn name(&self) -> &str {
|
||||
|
|
@ -1667,6 +2079,39 @@ mod tests {
|
|||
|
||||
struct MockPriceTool;
|
||||
|
||||
#[derive(Default)]
|
||||
struct ModelCaptureProvider {
|
||||
call_count: AtomicUsize,
|
||||
models: std::sync::Mutex<Vec<String>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for ModelCaptureProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("fallback".to_string())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
_messages: &[ChatMessage],
|
||||
model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
self.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.push(model.to_string());
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Tool for MockPriceTool {
|
||||
fn name(&self) -> &str {
|
||||
|
|
@ -1716,6 +2161,7 @@ mod tests {
|
|||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
|
|
@ -1726,6 +2172,13 @@ mod tests {
|
|||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
|
@ -1760,6 +2213,7 @@ mod tests {
|
|||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingAliasProvider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
|
|
@ -1770,6 +2224,13 @@ mod tests {
|
|||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
|
@ -1793,6 +2254,153 @@ mod tests {
|
|||
assert!(!sent_messages[0].contains("mock_price"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_handles_models_command_without_llm_call() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let fallback_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let fallback_provider: Arc<dyn Provider> = fallback_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("openrouter".to_string(), fallback_provider);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx.clone(),
|
||||
traits::ChannelMessage {
|
||||
id: "msg-cmd-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "/models openrouter".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let sent = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(sent[0].contains("Provider switched to `openrouter`"));
|
||||
|
||||
let route_key = "telegram_alice";
|
||||
let route = runtime_ctx
|
||||
.route_overrides
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.get(route_key)
|
||||
.cloned()
|
||||
.expect("route should be stored for sender");
|
||||
assert_eq!(route.provider, "openrouter");
|
||||
assert_eq!(route.model, "default-model");
|
||||
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(fallback_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_uses_route_override_provider_and_model() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||||
let routed_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let routed_provider: Arc<dyn Provider> = routed_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||||
provider_cache_seed.insert("openrouter".to_string(), routed_provider);
|
||||
|
||||
let route_key = "telegram_alice".to_string();
|
||||
let mut route_overrides = HashMap::new();
|
||||
route_overrides.insert(
|
||||
route_key,
|
||||
ChannelRouteSelection {
|
||||
provider: "openrouter".to_string(),
|
||||
model: "route-model".to_string(),
|
||||
},
|
||||
);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&default_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(route_overrides)),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-routed-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "hello routed provider".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 2,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(routed_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
routed_provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["route-model".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
struct NoopMemory;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
|
@ -1858,6 +2466,7 @@ mod tests {
|
|||
provider: Arc::new(SlowProvider {
|
||||
delay: Duration::from_millis(250),
|
||||
}),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
|
|
@ -1868,6 +2477,13 @@ mod tests {
|
|||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
|
|
@ -1920,6 +2536,7 @@ mod tests {
|
|||
provider: Arc::new(SlowProvider {
|
||||
delay: Duration::from_millis(20),
|
||||
}),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
|
|
@ -1930,6 +2547,13 @@ mod tests {
|
|||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
|
@ -2302,6 +2926,7 @@ mod tests {
|
|||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: provider_impl.clone(),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
|
|
@ -2312,6 +2937,13 @@ mod tests {
|
|||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue