fix(channel): hot-apply runtime config updates for running channel service
This commit is contained in:
parent
95ec5922d1
commit
740eb17d76
7 changed files with 410 additions and 22 deletions
|
|
@ -73,8 +73,8 @@ use std::fmt::Write;
|
|||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
/// Per-sender conversation history for channel messages.
|
||||
|
|
@ -139,6 +139,33 @@ struct ModelCacheEntry {
|
|||
models: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ChannelRuntimeDefaults {
|
||||
default_provider: String,
|
||||
model: String,
|
||||
temperature: f64,
|
||||
api_key: Option<String>,
|
||||
api_url: Option<String>,
|
||||
reliability: crate::config::ReliabilityConfig,
|
||||
}
|
||||
|
||||
/// Cheap fingerprint of the on-disk config file.
///
/// Two stamps compare equal when both the modification time and the byte
/// length match; the reload path uses that equality to decide whether the
/// file needs to be re-read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ConfigFileStamp {
    /// Modification time reported by the file's metadata.
    modified: SystemTime,
    /// File length in bytes reported by the file's metadata.
    len: u64,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RuntimeConfigState {
|
||||
defaults: ChannelRuntimeDefaults,
|
||||
last_applied_stamp: Option<ConfigFileStamp>,
|
||||
}
|
||||
|
||||
fn runtime_config_store() -> &'static Mutex<HashMap<PathBuf, RuntimeConfigState>> {
|
||||
static STORE: OnceLock<Mutex<HashMap<PathBuf, RuntimeConfigState>>> = OnceLock::new();
|
||||
STORE.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelRuntimeContext {
|
||||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
|
|
@ -318,10 +345,176 @@ fn resolve_provider_alias(name: &str) -> Option<String> {
|
|||
None
|
||||
}
|
||||
|
||||
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
|
||||
ChannelRouteSelection {
|
||||
provider: ctx.default_provider.as_str().to_string(),
|
||||
fn resolved_default_provider(config: &Config) -> String {
|
||||
config
|
||||
.default_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| "openrouter".to_string())
|
||||
}
|
||||
|
||||
fn resolved_default_model(config: &Config) -> String {
|
||||
config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".to_string())
|
||||
}
|
||||
|
||||
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
|
||||
ChannelRuntimeDefaults {
|
||||
default_provider: resolved_default_provider(config),
|
||||
model: resolved_default_model(config),
|
||||
temperature: config.default_temperature,
|
||||
api_key: config.api_key.clone(),
|
||||
api_url: config.api_url.clone(),
|
||||
reliability: config.reliability.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn runtime_config_path(ctx: &ChannelRuntimeContext) -> Option<PathBuf> {
|
||||
ctx.provider_runtime_options
|
||||
.zeroclaw_dir
|
||||
.as_ref()
|
||||
.map(|dir| dir.join("config.toml"))
|
||||
}
|
||||
|
||||
fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefaults {
|
||||
if let Some(config_path) = runtime_config_path(ctx) {
|
||||
let store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(state) = store.get(&config_path) {
|
||||
return state.defaults.clone();
|
||||
}
|
||||
}
|
||||
|
||||
ChannelRuntimeDefaults {
|
||||
default_provider: ctx.default_provider.as_str().to_string(),
|
||||
model: ctx.model.as_str().to_string(),
|
||||
temperature: ctx.temperature,
|
||||
api_key: ctx.api_key.clone(),
|
||||
api_url: ctx.api_url.clone(),
|
||||
reliability: (*ctx.reliability).clone(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn config_file_stamp(path: &Path) -> Option<ConfigFileStamp> {
|
||||
let metadata = tokio::fs::metadata(path).await.ok()?;
|
||||
let modified = metadata.modified().ok()?;
|
||||
Some(ConfigFileStamp {
|
||||
modified,
|
||||
len: metadata.len(),
|
||||
})
|
||||
}
|
||||
|
||||
fn decrypt_optional_secret_for_runtime_reload(
|
||||
store: &crate::security::SecretStore,
|
||||
value: &mut Option<String>,
|
||||
field_name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(raw) = value.clone() {
|
||||
if crate::security::SecretStore::is_encrypted(&raw) {
|
||||
*value = Some(
|
||||
store
|
||||
.decrypt(&raw)
|
||||
.with_context(|| format!("Failed to decrypt {field_name}"))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_runtime_defaults_from_config_file(path: &Path) -> Result<ChannelRuntimeDefaults> {
|
||||
let contents = tokio::fs::read_to_string(path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read {}", path.display()))?;
|
||||
let mut parsed: Config =
|
||||
toml::from_str(&contents).with_context(|| format!("Failed to parse {}", path.display()))?;
|
||||
parsed.config_path = path.to_path_buf();
|
||||
|
||||
if let Some(zeroclaw_dir) = path.parent() {
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, parsed.secrets.encrypt);
|
||||
decrypt_optional_secret_for_runtime_reload(&store, &mut parsed.api_key, "config.api_key")?;
|
||||
}
|
||||
|
||||
parsed.apply_env_overrides();
|
||||
Ok(runtime_defaults_from_config(&parsed))
|
||||
}
|
||||
|
||||
async fn maybe_apply_runtime_config_update(ctx: &ChannelRuntimeContext) -> Result<()> {
|
||||
let Some(config_path) = runtime_config_path(ctx) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let Some(stamp) = config_file_stamp(&config_path).await else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
{
|
||||
let store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(state) = store.get(&config_path) {
|
||||
if state.last_applied_stamp == Some(stamp) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let next_defaults = load_runtime_defaults_from_config_file(&config_path).await?;
|
||||
let next_default_provider = providers::create_resilient_provider_with_options(
|
||||
&next_defaults.default_provider,
|
||||
next_defaults.api_key.as_deref(),
|
||||
next_defaults.api_url.as_deref(),
|
||||
&next_defaults.reliability,
|
||||
&ctx.provider_runtime_options,
|
||||
)?;
|
||||
let next_default_provider: Arc<dyn Provider> = Arc::from(next_default_provider);
|
||||
|
||||
if let Err(err) = next_default_provider.warmup().await {
|
||||
tracing::warn!(
|
||||
provider = %next_defaults.default_provider,
|
||||
"Provider warmup failed after config reload: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
|
||||
cache.clear();
|
||||
cache.insert(
|
||||
next_defaults.default_provider.clone(),
|
||||
Arc::clone(&next_default_provider),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let mut store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
store.insert(
|
||||
config_path.clone(),
|
||||
RuntimeConfigState {
|
||||
defaults: next_defaults.clone(),
|
||||
last_applied_stamp: Some(stamp),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
path = %config_path.display(),
|
||||
provider = %next_defaults.default_provider,
|
||||
model = %next_defaults.model,
|
||||
temperature = next_defaults.temperature,
|
||||
"Applied updated channel runtime config from disk"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
|
||||
let defaults = runtime_defaults_snapshot(ctx);
|
||||
ChannelRouteSelection {
|
||||
provider: defaults.default_provider,
|
||||
model: defaults.model,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -456,10 +649,6 @@ async fn get_or_create_provider(
|
|||
ctx: &ChannelRuntimeContext,
|
||||
provider_name: &str,
|
||||
) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
if provider_name == ctx.default_provider.as_str() {
|
||||
return Ok(Arc::clone(&ctx.provider));
|
||||
}
|
||||
|
||||
if let Some(existing) = ctx
|
||||
.provider_cache
|
||||
.lock()
|
||||
|
|
@ -470,17 +659,22 @@ async fn get_or_create_provider(
|
|||
return Ok(existing);
|
||||
}
|
||||
|
||||
let api_url = if provider_name == ctx.default_provider.as_str() {
|
||||
ctx.api_url.as_deref()
|
||||
if provider_name == ctx.default_provider.as_str() {
|
||||
return Ok(Arc::clone(&ctx.provider));
|
||||
}
|
||||
|
||||
let defaults = runtime_defaults_snapshot(ctx);
|
||||
let api_url = if provider_name == defaults.default_provider.as_str() {
|
||||
defaults.api_url.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let provider = providers::create_resilient_provider_with_options(
|
||||
provider_name,
|
||||
ctx.api_key.as_deref(),
|
||||
defaults.api_key.as_deref(),
|
||||
api_url,
|
||||
&ctx.reliability,
|
||||
&defaults.reliability,
|
||||
&ctx.provider_runtime_options,
|
||||
)?;
|
||||
let provider: Arc<dyn Provider> = Arc::from(provider);
|
||||
|
|
@ -877,12 +1071,16 @@ async fn process_channel_message(
|
|||
);
|
||||
|
||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
if let Err(err) = maybe_apply_runtime_config_update(ctx.as_ref()).await {
|
||||
tracing::warn!("Failed to apply runtime config update: {err}");
|
||||
}
|
||||
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
|
||||
return;
|
||||
}
|
||||
|
||||
let history_key = conversation_history_key(&msg);
|
||||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
Err(err) => {
|
||||
|
|
@ -1036,7 +1234,7 @@ async fn process_channel_message(
|
|||
ctx.observer.as_ref(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
ctx.temperature,
|
||||
runtime_defaults.temperature,
|
||||
true,
|
||||
None,
|
||||
msg.channel.as_str(),
|
||||
|
|
@ -1964,10 +2162,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
|||
/// Start all configured channels and route messages to the agent
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let provider_name = config
|
||||
.default_provider
|
||||
.clone()
|
||||
.unwrap_or_else(|| "openrouter".into());
|
||||
let provider_name = resolved_default_provider(&config);
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||||
|
|
@ -1988,6 +2183,20 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
tracing::warn!("Provider warmup failed (non-fatal): {e}");
|
||||
}
|
||||
|
||||
let initial_stamp = config_file_stamp(&config.config_path).await;
|
||||
{
|
||||
let mut store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
store.insert(
|
||||
config.config_path.clone(),
|
||||
RuntimeConfigState {
|
||||
defaults: runtime_defaults_from_config(&config),
|
||||
last_applied_stamp: initial_stamp,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
|
|
@ -1996,10 +2205,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
&config.autonomy,
|
||||
&config.workspace_dir,
|
||||
));
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let model = resolved_default_model(&config);
|
||||
let temperature = config.default_temperature;
|
||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
|
||||
&config.memory,
|
||||
|
|
@ -3161,6 +3367,161 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_prefers_cached_default_provider_instance() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let startup_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let startup_provider: Arc<dyn Provider> = startup_provider_impl.clone();
|
||||
let reloaded_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let reloaded_provider: Arc<dyn Provider> = reloaded_provider_impl.clone();
|
||||
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), reloaded_provider);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&startup_provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("default-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-default-provider-cache".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "hello cached default provider".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 3,
|
||||
thread_ts: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(startup_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
assert_eq!(reloaded_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_uses_runtime_default_model_from_store() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let provider_impl = Arc::new(ModelCaptureProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&provider));
|
||||
|
||||
let temp = tempfile::TempDir::new().expect("temp dir");
|
||||
let config_path = temp.path().join("config.toml");
|
||||
|
||||
{
|
||||
let mut store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
store.insert(
|
||||
config_path.clone(),
|
||||
RuntimeConfigState {
|
||||
defaults: ChannelRuntimeDefaults {
|
||||
default_provider: "test-provider".to_string(),
|
||||
model: "hot-reloaded-model".to_string(),
|
||||
temperature: 0.5,
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
},
|
||||
last_applied_stamp: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::clone(&provider),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("startup-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions {
|
||||
zeroclaw_dir: Some(temp.path().to_path_buf()),
|
||||
..providers::ProviderRuntimeOptions::default()
|
||||
},
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-runtime-store-model".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "hello runtime defaults".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 4,
|
||||
thread_ts: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
{
|
||||
let mut store = runtime_config_store()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
store.remove(&config_path);
|
||||
}
|
||||
|
||||
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(
|
||||
provider_impl
|
||||
.models
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.as_slice(),
|
||||
&["hot-reloaded-model".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_respects_configured_max_tool_iterations_above_default() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue