fix(channel): hot-apply runtime config updates for running channel service

This commit is contained in:
Chummy 2026-02-20 01:51:32 +08:00
parent 95ec5922d1
commit 740eb17d76
7 changed files with 410 additions and 22 deletions

View file

@ -73,8 +73,8 @@ use std::fmt::Write;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant, SystemTime};
use tokio_util::sync::CancellationToken;
/// Per-sender conversation history for channel messages.
@ -139,6 +139,33 @@ struct ModelCacheEntry {
models: Vec<String>,
}
/// Provider/model defaults that the channel runtime resolves at message time.
///
/// Values come either from a parsed `Config` (see `runtime_defaults_from_config`)
/// or, when no reloaded state is stored, from the startup `ChannelRuntimeContext`
/// (see `runtime_defaults_snapshot`).
#[derive(Debug, Clone)]
struct ChannelRuntimeDefaults {
/// Provider name used by `default_route_selection`.
default_provider: String,
/// Model name used by `default_route_selection`.
model: String,
/// Temperature copied from the config's `default_temperature`.
temperature: f64,
/// API key handed to provider construction, if one is configured.
api_key: Option<String>,
/// API URL handed to provider construction, if one is configured.
api_url: Option<String>,
/// Reliability settings handed to provider construction.
reliability: crate::config::ReliabilityConfig,
}
/// Fingerprint of the config file taken from its filesystem metadata.
///
/// Two stamps compare equal when both the modification time and the length match;
/// `maybe_apply_runtime_config_update` uses that to decide whether a reload is needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ConfigFileStamp {
/// Last modification time reported by the file metadata.
modified: SystemTime,
/// File length reported by the file metadata.
len: u64,
}
/// Entry kept in the runtime config store for one config file path.
#[derive(Debug, Clone)]
struct RuntimeConfigState {
/// Defaults currently in effect for that config file.
defaults: ChannelRuntimeDefaults,
/// Stamp of the file version these defaults were applied from; `None` when no
/// stamp was available at the time the entry was stored.
last_applied_stamp: Option<ConfigFileStamp>,
}
/// Returns the process-wide table of runtime config state, keyed by config file path.
///
/// The table is created empty on first access and is shared by every caller afterwards.
fn runtime_config_store() -> &'static Mutex<HashMap<PathBuf, RuntimeConfigState>> {
    static STORE: OnceLock<Mutex<HashMap<PathBuf, RuntimeConfigState>>> = OnceLock::new();
    // `Mutex::default` wraps an empty `HashMap`, same as building it by hand.
    STORE.get_or_init(Mutex::default)
}
#[derive(Clone)]
struct ChannelRuntimeContext {
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
@ -318,10 +345,176 @@ fn resolve_provider_alias(name: &str) -> Option<String> {
None
}
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
ChannelRouteSelection {
provider: ctx.default_provider.as_str().to_string(),
/// Returns the configured default provider name, or `"openrouter"` when none is set.
fn resolved_default_provider(config: &Config) -> String {
    match config.default_provider.as_ref() {
        Some(provider_name) => provider_name.clone(),
        None => "openrouter".to_string(),
    }
}
/// Returns the configured default model, or `"anthropic/claude-sonnet-4-20250514"`
/// when none is set.
fn resolved_default_model(config: &Config) -> String {
    match config.default_model.as_ref() {
        Some(model_name) => model_name.clone(),
        None => "anthropic/claude-sonnet-4-20250514".to_string(),
    }
}
/// Extracts the hot-reloadable channel defaults from a parsed `Config`.
///
/// Provider and model go through their `resolved_default_*` helpers so missing
/// values get the same fallbacks everywhere; the remaining fields are copied as-is.
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
    let default_provider = resolved_default_provider(config);
    let model = resolved_default_model(config);

    ChannelRuntimeDefaults {
        default_provider,
        model,
        temperature: config.default_temperature,
        api_key: config.api_key.clone(),
        api_url: config.api_url.clone(),
        reliability: config.reliability.clone(),
    }
}
/// Returns `<zeroclaw_dir>/config.toml` for this context, or `None` when the
/// provider runtime options carry no `zeroclaw_dir`.
fn runtime_config_path(ctx: &ChannelRuntimeContext) -> Option<PathBuf> {
    let zeroclaw_dir = ctx.provider_runtime_options.zeroclaw_dir.as_ref()?;
    Some(zeroclaw_dir.join("config.toml"))
}
/// Returns the defaults currently in effect for `ctx`.
///
/// If the runtime config store holds an entry for this context's config path, a
/// clone of that entry's defaults is returned. Otherwise the values captured in
/// the context itself are used.
fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefaults {
    let stored_defaults = runtime_config_path(ctx).and_then(|config_path| {
        let guard = runtime_config_store()
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        guard.get(&config_path).map(|state| state.defaults.clone())
    });

    stored_defaults.unwrap_or_else(|| ChannelRuntimeDefaults {
        default_provider: ctx.default_provider.as_str().to_string(),
        model: ctx.model.as_str().to_string(),
        temperature: ctx.temperature,
        api_key: ctx.api_key.clone(),
        api_url: ctx.api_url.clone(),
        reliability: (*ctx.reliability).clone(),
    })
}
/// Reads the modification time and length of `path` into a `ConfigFileStamp`.
///
/// Returns `None` when the metadata cannot be read or carries no modification time.
async fn config_file_stamp(path: &Path) -> Option<ConfigFileStamp> {
    let file_meta = tokio::fs::metadata(path).await.ok()?;
    let len = file_meta.len();
    file_meta
        .modified()
        .ok()
        .map(|modified| ConfigFileStamp { modified, len })
}
/// Decrypts `value` in place when it holds an encrypted secret.
///
/// `None` and values that `SecretStore::is_encrypted` does not recognise are left
/// untouched. `field_name` is only used to label the error.
///
/// # Errors
/// Returns the decryption error, with `Failed to decrypt <field_name>` as context;
/// `value` keeps its original content in that case.
fn decrypt_optional_secret_for_runtime_reload(
    store: &crate::security::SecretStore,
    value: &mut Option<String>,
    field_name: &str,
) -> Result<()> {
    let Some(secret) = value.as_mut() else {
        return Ok(());
    };
    if !crate::security::SecretStore::is_encrypted(secret.as_str()) {
        return Ok(());
    }

    // Decrypt first; the stored value is only overwritten once that succeeded.
    let plaintext = store
        .decrypt(secret.as_str())
        .with_context(|| format!("Failed to decrypt {field_name}"))?;
    *secret = plaintext;
    Ok(())
}
/// Reads and parses the config file at `path` and returns its runtime defaults.
///
/// After parsing, the config's `config_path` is set to `path`, an encrypted
/// `api_key` is decrypted using a `SecretStore` rooted at the file's parent
/// directory, and environment overrides are applied last.
///
/// # Errors
/// Fails when the file cannot be read, is not valid config TOML, or the API key
/// cannot be decrypted.
async fn load_runtime_defaults_from_config_file(path: &Path) -> Result<ChannelRuntimeDefaults> {
    let raw_toml = tokio::fs::read_to_string(path)
        .await
        .with_context(|| format!("Failed to read {}", path.display()))?;
    let mut config = toml::from_str::<Config>(&raw_toml)
        .with_context(|| format!("Failed to parse {}", path.display()))?;
    config.config_path = path.to_path_buf();

    if let Some(zeroclaw_dir) = path.parent() {
        let secrets = crate::security::SecretStore::new(zeroclaw_dir, config.secrets.encrypt);
        decrypt_optional_secret_for_runtime_reload(
            &secrets,
            &mut config.api_key,
            "config.api_key",
        )?;
    }

    config.apply_env_overrides();
    Ok(runtime_defaults_from_config(&config))
}
/// Re-applies the on-disk config when its file stamp differs from the last applied one.
///
/// Does nothing when the context has no config path, the file cannot be stat'ed,
/// or the stored stamp already matches. Otherwise it reloads the defaults, builds
/// a provider for the new default, replaces the whole provider cache with that
/// single provider, and records the new defaults and stamp in the store.
///
/// # Errors
/// Propagates failures from loading the config file and from constructing the
/// provider. Nothing is stored on failure, so the same file version is attempted
/// again on the next call. A failed provider warmup is only logged.
async fn maybe_apply_runtime_config_update(ctx: &ChannelRuntimeContext) -> Result<()> {
    let Some(config_path) = runtime_config_path(ctx) else {
        return Ok(());
    };
    let Some(stamp) = config_file_stamp(&config_path).await else {
        return Ok(());
    };

    // The lock guard is a temporary of this statement, so it is released before
    // any of the awaits below.
    let already_applied = runtime_config_store()
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .get(&config_path)
        .is_some_and(|state| state.last_applied_stamp == Some(stamp));
    if already_applied {
        return Ok(());
    }

    let reloaded = load_runtime_defaults_from_config_file(&config_path).await?;

    let boxed_provider = providers::create_resilient_provider_with_options(
        &reloaded.default_provider,
        reloaded.api_key.as_deref(),
        reloaded.api_url.as_deref(),
        &reloaded.reliability,
        &ctx.provider_runtime_options,
    )?;
    let reloaded_provider: Arc<dyn Provider> = Arc::from(boxed_provider);

    // Warmup is best effort: the new provider is installed either way.
    if let Err(err) = reloaded_provider.warmup().await {
        tracing::warn!(
            provider = %reloaded.default_provider,
            "Provider warmup failed after config reload: {err}"
        );
    }

    // Drop every cached provider so none keeps running with the previous settings,
    // then seed the cache with the freshly built default.
    {
        let mut providers_by_name = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
        providers_by_name.clear();
        providers_by_name.insert(
            reloaded.default_provider.clone(),
            Arc::clone(&reloaded_provider),
        );
    }

    let applied_state = RuntimeConfigState {
        defaults: reloaded.clone(),
        last_applied_stamp: Some(stamp),
    };
    runtime_config_store()
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .insert(config_path.clone(), applied_state);

    tracing::info!(
        path = %config_path.display(),
        provider = %reloaded.default_provider,
        model = %reloaded.model,
        temperature = reloaded.temperature,
        "Applied updated channel runtime config from disk"
    );

    Ok(())
}
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
let defaults = runtime_defaults_snapshot(ctx);
ChannelRouteSelection {
provider: defaults.default_provider,
model: defaults.model,
}
}
@ -456,10 +649,6 @@ async fn get_or_create_provider(
ctx: &ChannelRuntimeContext,
provider_name: &str,
) -> anyhow::Result<Arc<dyn Provider>> {
if provider_name == ctx.default_provider.as_str() {
return Ok(Arc::clone(&ctx.provider));
}
if let Some(existing) = ctx
.provider_cache
.lock()
@ -470,17 +659,22 @@ async fn get_or_create_provider(
return Ok(existing);
}
let api_url = if provider_name == ctx.default_provider.as_str() {
ctx.api_url.as_deref()
if provider_name == ctx.default_provider.as_str() {
return Ok(Arc::clone(&ctx.provider));
}
let defaults = runtime_defaults_snapshot(ctx);
let api_url = if provider_name == defaults.default_provider.as_str() {
defaults.api_url.as_deref()
} else {
None
};
let provider = providers::create_resilient_provider_with_options(
provider_name,
ctx.api_key.as_deref(),
defaults.api_key.as_deref(),
api_url,
&ctx.reliability,
&defaults.reliability,
&ctx.provider_runtime_options,
)?;
let provider: Arc<dyn Provider> = Arc::from(provider);
@ -877,12 +1071,16 @@ async fn process_channel_message(
);
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Err(err) = maybe_apply_runtime_config_update(ctx.as_ref()).await {
tracing::warn!("Failed to apply runtime config update: {err}");
}
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
return;
}
let history_key = conversation_history_key(&msg);
let route = get_route_selection(ctx.as_ref(), &history_key);
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
Ok(provider) => provider,
Err(err) => {
@ -1036,7 +1234,7 @@ async fn process_channel_message(
ctx.observer.as_ref(),
route.provider.as_str(),
route.model.as_str(),
ctx.temperature,
runtime_defaults.temperature,
true,
None,
msg.channel.as_str(),
@ -1964,10 +2162,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
/// Start all configured channels and route messages to the agent
#[allow(clippy::too_many_lines)]
pub async fn start_channels(config: Config) -> Result<()> {
let provider_name = config
.default_provider
.clone()
.unwrap_or_else(|| "openrouter".into());
let provider_name = resolved_default_provider(&config);
let provider_runtime_options = providers::ProviderRuntimeOptions {
auth_profile_override: None,
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
@ -1988,6 +2183,20 @@ pub async fn start_channels(config: Config) -> Result<()> {
tracing::warn!("Provider warmup failed (non-fatal): {e}");
}
let initial_stamp = config_file_stamp(&config.config_path).await;
{
let mut store = runtime_config_store()
.lock()
.unwrap_or_else(|e| e.into_inner());
store.insert(
config.config_path.clone(),
RuntimeConfigState {
defaults: runtime_defaults_from_config(&config),
last_applied_stamp: initial_stamp,
},
);
}
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let runtime: Arc<dyn runtime::RuntimeAdapter> =
@ -1996,10 +2205,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
&config.autonomy,
&config.workspace_dir,
));
let model = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let model = resolved_default_model(&config);
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
@ -3161,6 +3367,161 @@ mod tests {
);
}
/// A provider cached under the default provider's name must be used instead of
/// the provider instance the context was started with.
#[tokio::test]
async fn process_channel_message_prefers_cached_default_provider_instance() {
    let channel_impl = Arc::new(TelegramRecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let channels_by_name = HashMap::from([(channel.name().to_string(), channel)]);

    // Two distinct providers: one wired in at startup, one sitting in the cache.
    let startup_provider_impl = Arc::new(ModelCaptureProvider::default());
    let startup_provider: Arc<dyn Provider> = startup_provider_impl.clone();
    let reloaded_provider_impl = Arc::new(ModelCaptureProvider::default());
    let reloaded_provider: Arc<dyn Provider> = reloaded_provider_impl.clone();

    let seeded_cache: HashMap<String, Arc<dyn Provider>> =
        HashMap::from([("test-provider".to_string(), reloaded_provider)]);

    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::clone(&startup_provider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("default-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(HashMap::new())),
        provider_cache: Arc::new(Mutex::new(seeded_cache)),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
        message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
        multimodal: crate::config::MultimodalConfig::default(),
    });

    let inbound = traits::ChannelMessage {
        id: "msg-default-provider-cache".to_string(),
        sender: "alice".to_string(),
        reply_target: "chat-1".to_string(),
        content: "hello cached default provider".to_string(),
        channel: "telegram".to_string(),
        timestamp: 3,
        thread_ts: None,
    };
    process_channel_message(runtime_ctx, inbound).await;

    // Only the cached instance may have been called.
    assert_eq!(startup_provider_impl.call_count.load(Ordering::SeqCst), 0);
    assert_eq!(reloaded_provider_impl.call_count.load(Ordering::SeqCst), 1);
}
/// Defaults stored in the runtime config store must win over the model the
/// context was started with.
#[tokio::test]
async fn process_channel_message_uses_runtime_default_model_from_store() {
    let channel_impl = Arc::new(TelegramRecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let channels_by_name = HashMap::from([(channel.name().to_string(), channel)]);

    let provider_impl = Arc::new(ModelCaptureProvider::default());
    let provider: Arc<dyn Provider> = provider_impl.clone();
    let seeded_cache: HashMap<String, Arc<dyn Provider>> =
        HashMap::from([("test-provider".to_string(), Arc::clone(&provider))]);

    // The store is keyed by `<zeroclaw_dir>/config.toml`; the file itself is never created.
    let temp = tempfile::TempDir::new().expect("temp dir");
    let config_path = temp.path().join("config.toml");
    let stored_state = RuntimeConfigState {
        defaults: ChannelRuntimeDefaults {
            default_provider: "test-provider".to_string(),
            model: "hot-reloaded-model".to_string(),
            temperature: 0.5,
            api_key: None,
            api_url: None,
            reliability: crate::config::ReliabilityConfig::default(),
        },
        last_applied_stamp: None,
    };
    runtime_config_store()
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .insert(config_path.clone(), stored_state);

    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::clone(&provider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("startup-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(HashMap::new())),
        provider_cache: Arc::new(Mutex::new(seeded_cache)),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        provider_runtime_options: providers::ProviderRuntimeOptions {
            zeroclaw_dir: Some(temp.path().to_path_buf()),
            ..providers::ProviderRuntimeOptions::default()
        },
        workspace_dir: Arc::new(std::env::temp_dir()),
        message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
        multimodal: crate::config::MultimodalConfig::default(),
    });

    let inbound = traits::ChannelMessage {
        id: "msg-runtime-store-model".to_string(),
        sender: "alice".to_string(),
        reply_target: "chat-1".to_string(),
        content: "hello runtime defaults".to_string(),
        channel: "telegram".to_string(),
        timestamp: 4,
        thread_ts: None,
    };
    process_channel_message(runtime_ctx, inbound).await;

    // Remove the entry before asserting so a failure does not leave it in the global store.
    runtime_config_store()
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .remove(&config_path);

    assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 1);
    let captured_models = provider_impl
        .models
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    assert_eq!(
        captured_models.as_slice(),
        &["hot-reloaded-model".to_string()]
    );
}
#[tokio::test]
async fn process_channel_message_respects_configured_max_tool_iterations_above_default() {
let channel_impl = Arc::new(RecordingChannel::default());