4305 lines
153 KiB
Rust
4305 lines
153 KiB
Rust
//! Channel subsystem for messaging platform integrations.
|
||
//!
|
||
//! This module provides the multi-channel messaging infrastructure that connects
|
||
//! ZeroClaw to external platforms. Each channel implements the [`Channel`] trait
|
||
//! defined in [`traits`], which provides a uniform interface for sending messages,
|
||
//! listening for incoming messages, health checking, and typing indicators.
|
||
//!
|
||
//! Channels are instantiated by [`start_channels`] based on the runtime configuration.
|
||
//! The subsystem manages per-sender conversation history, concurrent message processing
|
||
//! with configurable parallelism, and exponential-backoff reconnection for resilience.
|
||
//!
|
||
//! # Extension
|
||
//!
|
||
//! To add a new channel, implement [`Channel`] in a new submodule and wire it into
|
||
//! [`start_channels`]. See `AGENTS.md` §7.2 for the full change playbook.
|
||
|
||
pub mod cli;
|
||
pub mod dingtalk;
|
||
pub mod discord;
|
||
pub mod email_channel;
|
||
pub mod imessage;
|
||
pub mod irc;
|
||
pub mod lark;
|
||
pub mod linq;
|
||
#[cfg(feature = "channel-matrix")]
|
||
pub mod matrix;
|
||
pub mod mattermost;
|
||
pub mod qq;
|
||
pub mod signal;
|
||
pub mod slack;
|
||
pub mod telegram;
|
||
pub mod traits;
|
||
pub mod whatsapp;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub mod whatsapp_storage;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub mod whatsapp_web;
|
||
|
||
pub use cli::CliChannel;
|
||
pub use dingtalk::DingTalkChannel;
|
||
pub use discord::DiscordChannel;
|
||
pub use email_channel::EmailChannel;
|
||
pub use imessage::IMessageChannel;
|
||
pub use irc::IrcChannel;
|
||
pub use lark::LarkChannel;
|
||
pub use linq::LinqChannel;
|
||
#[cfg(feature = "channel-matrix")]
|
||
pub use matrix::MatrixChannel;
|
||
pub use mattermost::MattermostChannel;
|
||
pub use qq::QQChannel;
|
||
pub use signal::SignalChannel;
|
||
pub use slack::SlackChannel;
|
||
pub use telegram::TelegramChannel;
|
||
pub use traits::{Channel, SendMessage};
|
||
pub use whatsapp::WhatsAppChannel;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub use whatsapp_web::WhatsAppWebChannel;
|
||
|
||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop};
|
||
use crate::config::Config;
|
||
use crate::identity;
|
||
use crate::memory::{self, Memory};
|
||
use crate::observability::{self, Observer};
|
||
use crate::providers::{self, ChatMessage, Provider};
|
||
use crate::runtime;
|
||
use crate::security::SecurityPolicy;
|
||
use crate::tools::{self, Tool};
|
||
use crate::util::truncate_with_ellipsis;
|
||
use anyhow::{Context, Result};
|
||
use serde::Deserialize;
|
||
use std::collections::HashMap;
|
||
use std::fmt::Write;
|
||
use std::path::{Path, PathBuf};
|
||
use std::process::Command;
|
||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||
use std::sync::{Arc, Mutex};
|
||
use std::time::{Duration, Instant};
|
||
use tokio_util::sync::CancellationToken;
|
||
|
||
/// Per-sender conversation history for channel messages.
|
||
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
|
||
/// Maximum history messages to keep per sender.
|
||
const MAX_CHANNEL_HISTORY: usize = 50;
|
||
|
||
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
|
||
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||
|
||
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
|
||
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
|
||
const MIN_CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 30;
|
||
/// Default timeout for processing a single channel message (LLM + tools).
|
||
/// Used as fallback when not configured in channels_config.message_timeout_secs.
|
||
const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300;
|
||
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
||
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
||
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
||
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
|
||
const MODEL_CACHE_FILE: &str = "models_cache.json";
|
||
const MODEL_CACHE_PREVIEW_LIMIT: usize = 10;
|
||
const MEMORY_CONTEXT_MAX_ENTRIES: usize = 4;
|
||
const MEMORY_CONTEXT_ENTRY_MAX_CHARS: usize = 800;
|
||
const MEMORY_CONTEXT_MAX_CHARS: usize = 4_000;
|
||
const CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES: usize = 12;
|
||
const CHANNEL_HISTORY_COMPACT_CONTENT_CHARS: usize = 600;
|
||
|
||
type ProviderCacheMap = Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>;
|
||
type RouteSelectionMap = Arc<Mutex<HashMap<String, ChannelRouteSelection>>>;
|
||
|
||
fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||
configured.max(MIN_CHANNEL_MESSAGE_TIMEOUT_SECS)
|
||
}
|
||
|
||
/// Provider/model pair used to answer a sender's messages.
///
/// A sender with no override uses the runtime defaults; see `default_route_selection`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ChannelRouteSelection {
    /// Canonical provider name (as returned by `resolve_provider_alias`).
    provider: String,
    /// Model identifier passed to the provider.
    model: String,
}
|
||
|
||
/// In-chat runtime commands recognised by `parse_runtime_command`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ChannelRuntimeCommand {
    /// `/models` with no argument: list available providers.
    ShowProviders,
    /// `/models <provider>`: switch this sender to another provider.
    SetProvider(String),
    /// `/model` with no argument: show the current model and cached model IDs.
    ShowModel,
    /// `/model <model-id>`: switch this sender to another model.
    SetModel(String),
}
|
||
|
||
#[derive(Debug, Clone, Default, Deserialize)]
|
||
struct ModelCacheState {
|
||
entries: Vec<ModelCacheEntry>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Default, Deserialize)]
|
||
struct ModelCacheEntry {
|
||
provider: String,
|
||
models: Vec<String>,
|
||
}
|
||
|
||
#[derive(Clone)]
|
||
struct ChannelRuntimeContext {
|
||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||
provider: Arc<dyn Provider>,
|
||
default_provider: Arc<String>,
|
||
memory: Arc<dyn Memory>,
|
||
tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||
observer: Arc<dyn Observer>,
|
||
system_prompt: Arc<String>,
|
||
model: Arc<String>,
|
||
temperature: f64,
|
||
auto_save_memory: bool,
|
||
max_tool_iterations: usize,
|
||
min_relevance_score: f64,
|
||
conversation_histories: ConversationHistoryMap,
|
||
provider_cache: ProviderCacheMap,
|
||
route_overrides: RouteSelectionMap,
|
||
api_key: Option<String>,
|
||
api_url: Option<String>,
|
||
reliability: Arc<crate::config::ReliabilityConfig>,
|
||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||
workspace_dir: Arc<PathBuf>,
|
||
message_timeout_secs: u64,
|
||
interrupt_on_new_message: bool,
|
||
multimodal: crate::config::MultimodalConfig,
|
||
}
|
||
|
||
#[derive(Clone)]
|
||
struct InFlightSenderTaskState {
|
||
task_id: u64,
|
||
cancellation: CancellationToken,
|
||
completion: Arc<InFlightTaskCompletion>,
|
||
}
|
||
|
||
struct InFlightTaskCompletion {
|
||
done: AtomicBool,
|
||
notify: tokio::sync::Notify,
|
||
}
|
||
|
||
impl InFlightTaskCompletion {
|
||
fn new() -> Self {
|
||
Self {
|
||
done: AtomicBool::new(false),
|
||
notify: tokio::sync::Notify::new(),
|
||
}
|
||
}
|
||
|
||
fn mark_done(&self) {
|
||
self.done.store(true, Ordering::Release);
|
||
self.notify.notify_waiters();
|
||
}
|
||
|
||
async fn wait(&self) {
|
||
if self.done.load(Ordering::Acquire) {
|
||
return;
|
||
}
|
||
self.notify.notified().await;
|
||
}
|
||
}
|
||
|
||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||
}
|
||
|
||
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
||
format!("{}_{}", msg.channel, msg.sender)
|
||
}
|
||
|
||
fn interruption_scope_key(msg: &traits::ChannelMessage) -> String {
|
||
format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender)
|
||
}
|
||
|
||
/// Extra system-prompt instructions for channels with special delivery syntax.
///
/// Only Telegram currently has any; every other channel gets `None`.
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
    const TELEGRAM_DELIVERY: &str = "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.";

    (channel_name == "telegram").then_some(TELEGRAM_DELIVERY)
}

/// Builds the system prompt for `channel_name`.
///
/// The channel's delivery instructions are appended after a blank line, or used alone
/// when `base_prompt` is empty.
fn build_channel_system_prompt(base_prompt: &str, channel_name: &str) -> String {
    match channel_delivery_instructions(channel_name) {
        None => base_prompt.to_string(),
        Some(instructions) if base_prompt.is_empty() => instructions.to_string(),
        Some(instructions) => format!("{base_prompt}\n\n{instructions}"),
    }
}
|
||
|
||
/// Normalizes cached turns into a strict `user, assistant, user, …` sequence.
///
/// Rules:
/// - Leading assistant turns are dropped.
/// - Roles other than `user`/`assistant` are dropped.
/// - Consecutive turns with the same role are *merged* into the preceding turn, with
///   their contents joined by a blank line. Empty contents add no separator.
///
/// Merging matters because the user turn is appended to history *before* the LLM call.
/// A cancelled, timed-out or failed request therefore leaves a user turn without an
/// assistant reply. Previously the next (newest) user turn was dropped in that case,
/// so the model saw only the stale message and answered it instead of the new one.
fn normalize_cached_channel_turns(turns: Vec<ChatMessage>) -> Vec<ChatMessage> {
    let mut normalized: Vec<ChatMessage> = Vec::with_capacity(turns.len());
    let mut expecting_user = true;

    for turn in turns {
        match (expecting_user, turn.role.as_str()) {
            (true, "user") => {
                normalized.push(turn);
                expecting_user = false;
            }
            (false, "assistant") => {
                normalized.push(turn);
                expecting_user = true;
            }
            // Same role twice in a row: fold into the previous turn instead of losing it.
            // With nothing to fold into (a leading assistant turn), the turn is dropped.
            (false, "user") | (true, "assistant") => {
                if let Some(last) = normalized.last_mut() {
                    if !turn.content.is_empty() {
                        if !last.content.is_empty() {
                            last.content.push_str("\n\n");
                        }
                        last.content.push_str(&turn.content);
                    }
                }
            }
            _ => {}
        }
    }

    normalized
}
|
||
|
||
fn supports_runtime_model_switch(channel_name: &str) -> bool {
|
||
matches!(channel_name, "telegram" | "discord")
|
||
}
|
||
|
||
fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
|
||
if !supports_runtime_model_switch(channel_name) {
|
||
return None;
|
||
}
|
||
|
||
let trimmed = content.trim();
|
||
if !trimmed.starts_with('/') {
|
||
return None;
|
||
}
|
||
|
||
let mut parts = trimmed.split_whitespace();
|
||
let command_token = parts.next()?;
|
||
let base_command = command_token
|
||
.split('@')
|
||
.next()
|
||
.unwrap_or(command_token)
|
||
.to_ascii_lowercase();
|
||
|
||
match base_command.as_str() {
|
||
"/models" => {
|
||
if let Some(provider) = parts.next() {
|
||
Some(ChannelRuntimeCommand::SetProvider(
|
||
provider.trim().to_string(),
|
||
))
|
||
} else {
|
||
Some(ChannelRuntimeCommand::ShowProviders)
|
||
}
|
||
}
|
||
"/model" => {
|
||
let model = parts.collect::<Vec<_>>().join(" ").trim().to_string();
|
||
if model.is_empty() {
|
||
Some(ChannelRuntimeCommand::ShowModel)
|
||
} else {
|
||
Some(ChannelRuntimeCommand::SetModel(model))
|
||
}
|
||
}
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
fn resolve_provider_alias(name: &str) -> Option<String> {
|
||
let candidate = name.trim();
|
||
if candidate.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
let providers_list = providers::list_providers();
|
||
for provider in providers_list {
|
||
if provider.name.eq_ignore_ascii_case(candidate)
|
||
|| provider
|
||
.aliases
|
||
.iter()
|
||
.any(|alias| alias.eq_ignore_ascii_case(candidate))
|
||
{
|
||
return Some(provider.name.to_string());
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
|
||
ChannelRouteSelection {
|
||
provider: ctx.default_provider.as_str().to_string(),
|
||
model: ctx.model.as_str().to_string(),
|
||
}
|
||
}
|
||
|
||
fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> ChannelRouteSelection {
|
||
ctx.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(sender_key)
|
||
.cloned()
|
||
.unwrap_or_else(|| default_route_selection(ctx))
|
||
}
|
||
|
||
fn set_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str, next: ChannelRouteSelection) {
|
||
let default_route = default_route_selection(ctx);
|
||
let mut routes = ctx
|
||
.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
if next == default_route {
|
||
routes.remove(sender_key);
|
||
} else {
|
||
routes.insert(sender_key.to_string(), next);
|
||
}
|
||
}
|
||
|
||
fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) {
|
||
ctx.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.remove(sender_key);
|
||
}
|
||
|
||
fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool {
|
||
let mut histories = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
|
||
let Some(turns) = histories.get_mut(sender_key) else {
|
||
return false;
|
||
};
|
||
|
||
if turns.is_empty() {
|
||
return false;
|
||
}
|
||
|
||
let keep_from = turns
|
||
.len()
|
||
.saturating_sub(CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES);
|
||
let mut compacted = normalize_cached_channel_turns(turns[keep_from..].to_vec());
|
||
|
||
for turn in &mut compacted {
|
||
if turn.content.chars().count() > CHANNEL_HISTORY_COMPACT_CONTENT_CHARS {
|
||
turn.content =
|
||
truncate_with_ellipsis(&turn.content, CHANNEL_HISTORY_COMPACT_CONTENT_CHARS);
|
||
}
|
||
}
|
||
|
||
if compacted.is_empty() {
|
||
turns.clear();
|
||
return false;
|
||
}
|
||
|
||
*turns = compacted;
|
||
true
|
||
}
|
||
|
||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||
let mut histories = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
let turns = histories.entry(sender_key.to_string()).or_default();
|
||
turns.push(turn);
|
||
while turns.len() > MAX_CHANNEL_HISTORY {
|
||
turns.remove(0);
|
||
}
|
||
}
|
||
|
||
fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||
if memory::is_assistant_autosave_key(key) {
|
||
return true;
|
||
}
|
||
|
||
if key.trim().to_ascii_lowercase().ends_with("_history") {
|
||
return true;
|
||
}
|
||
|
||
content.chars().count() > MEMORY_CONTEXT_MAX_CHARS
|
||
}
|
||
|
||
fn is_context_window_overflow_error(err: &anyhow::Error) -> bool {
|
||
let lower = err.to_string().to_lowercase();
|
||
[
|
||
"exceeds the context window",
|
||
"context window of this model",
|
||
"maximum context length",
|
||
"context length exceeded",
|
||
"too many tokens",
|
||
"token limit exceeded",
|
||
"prompt is too long",
|
||
"input is too long",
|
||
]
|
||
.iter()
|
||
.any(|hint| lower.contains(hint))
|
||
}
|
||
|
||
fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<String> {
|
||
let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE);
|
||
let Ok(raw) = std::fs::read_to_string(cache_path) else {
|
||
return Vec::new();
|
||
};
|
||
let Ok(state) = serde_json::from_str::<ModelCacheState>(&raw) else {
|
||
return Vec::new();
|
||
};
|
||
|
||
state
|
||
.entries
|
||
.into_iter()
|
||
.find(|entry| entry.provider == provider_name)
|
||
.map(|entry| {
|
||
entry
|
||
.models
|
||
.into_iter()
|
||
.take(MODEL_CACHE_PREVIEW_LIMIT)
|
||
.collect::<Vec<_>>()
|
||
})
|
||
.unwrap_or_default()
|
||
}
|
||
|
||
async fn get_or_create_provider(
|
||
ctx: &ChannelRuntimeContext,
|
||
provider_name: &str,
|
||
) -> anyhow::Result<Arc<dyn Provider>> {
|
||
if provider_name == ctx.default_provider.as_str() {
|
||
return Ok(Arc::clone(&ctx.provider));
|
||
}
|
||
|
||
if let Some(existing) = ctx
|
||
.provider_cache
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(provider_name)
|
||
.cloned()
|
||
{
|
||
return Ok(existing);
|
||
}
|
||
|
||
let api_url = if provider_name == ctx.default_provider.as_str() {
|
||
ctx.api_url.as_deref()
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let provider = providers::create_resilient_provider_with_options(
|
||
provider_name,
|
||
ctx.api_key.as_deref(),
|
||
api_url,
|
||
&ctx.reliability,
|
||
&ctx.provider_runtime_options,
|
||
)?;
|
||
let provider: Arc<dyn Provider> = Arc::from(provider);
|
||
|
||
if let Err(err) = provider.warmup().await {
|
||
tracing::warn!(provider = provider_name, "Provider warmup failed: {err}");
|
||
}
|
||
|
||
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
|
||
let cached = cache
|
||
.entry(provider_name.to_string())
|
||
.or_insert_with(|| Arc::clone(&provider));
|
||
Ok(Arc::clone(cached))
|
||
}
|
||
|
||
fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String {
|
||
let mut response = String::new();
|
||
let _ = writeln!(
|
||
response,
|
||
"Current provider: `{}`\nCurrent model: `{}`",
|
||
current.provider, current.model
|
||
);
|
||
response.push_str("\nSwitch model with `/model <model-id>`.\n");
|
||
|
||
let cached_models = load_cached_model_preview(workspace_dir, ¤t.provider);
|
||
if cached_models.is_empty() {
|
||
let _ = writeln!(
|
||
response,
|
||
"\nNo cached model list found for `{}`. Ask the operator to run `zeroclaw models refresh --provider {}`.",
|
||
current.provider, current.provider
|
||
);
|
||
} else {
|
||
let _ = writeln!(
|
||
response,
|
||
"\nCached model IDs (top {}):",
|
||
cached_models.len()
|
||
);
|
||
for model in cached_models {
|
||
let _ = writeln!(response, "- `{model}`");
|
||
}
|
||
}
|
||
|
||
response
|
||
}
|
||
|
||
fn build_providers_help_response(current: &ChannelRouteSelection) -> String {
|
||
let mut response = String::new();
|
||
let _ = writeln!(
|
||
response,
|
||
"Current provider: `{}`\nCurrent model: `{}`",
|
||
current.provider, current.model
|
||
);
|
||
response.push_str("\nSwitch provider with `/models <provider>`.\n");
|
||
response.push_str("Switch model with `/model <model-id>`.\n\n");
|
||
response.push_str("Available providers:\n");
|
||
for provider in providers::list_providers() {
|
||
if provider.aliases.is_empty() {
|
||
let _ = writeln!(response, "- {}", provider.name);
|
||
} else {
|
||
let _ = writeln!(
|
||
response,
|
||
"- {} (aliases: {})",
|
||
provider.name,
|
||
provider.aliases.join(", ")
|
||
);
|
||
}
|
||
}
|
||
response
|
||
}
|
||
|
||
async fn handle_runtime_command_if_needed(
|
||
ctx: &ChannelRuntimeContext,
|
||
msg: &traits::ChannelMessage,
|
||
target_channel: Option<&Arc<dyn Channel>>,
|
||
) -> bool {
|
||
let Some(command) = parse_runtime_command(&msg.channel, &msg.content) else {
|
||
return false;
|
||
};
|
||
|
||
let Some(channel) = target_channel else {
|
||
return true;
|
||
};
|
||
|
||
let sender_key = conversation_history_key(msg);
|
||
let mut current = get_route_selection(ctx, &sender_key);
|
||
|
||
let response = match command {
|
||
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t),
|
||
ChannelRuntimeCommand::SetProvider(raw_provider) => {
|
||
match resolve_provider_alias(&raw_provider) {
|
||
Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
|
||
Ok(_) => {
|
||
if provider_name != current.provider {
|
||
current.provider = provider_name.clone();
|
||
set_route_selection(ctx, &sender_key, current.clone());
|
||
clear_sender_history(ctx, &sender_key);
|
||
}
|
||
|
||
format!(
|
||
"Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
|
||
current.model
|
||
)
|
||
}
|
||
Err(err) => {
|
||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||
format!(
|
||
"Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
|
||
)
|
||
}
|
||
},
|
||
None => format!(
|
||
"Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
|
||
),
|
||
}
|
||
}
|
||
ChannelRuntimeCommand::ShowModel => {
|
||
build_models_help_response(¤t, ctx.workspace_dir.as_path())
|
||
}
|
||
ChannelRuntimeCommand::SetModel(raw_model) => {
|
||
let model = raw_model.trim().trim_matches('`').to_string();
|
||
if model.is_empty() {
|
||
"Model ID cannot be empty. Use `/model <model-id>`.".to_string()
|
||
} else {
|
||
current.model = model.clone();
|
||
set_route_selection(ctx, &sender_key, current.clone());
|
||
clear_sender_history(ctx, &sender_key);
|
||
|
||
format!(
|
||
"Model switched to `{model}` for provider `{}` in this sender session.",
|
||
current.provider
|
||
)
|
||
}
|
||
}
|
||
};
|
||
|
||
if let Err(err) = channel
|
||
.send(&SendMessage::new(response, &msg.reply_target).in_thread(msg.thread_ts.clone()))
|
||
.await
|
||
{
|
||
tracing::warn!(
|
||
"Failed to send runtime command response on {}: {err}",
|
||
channel.name()
|
||
);
|
||
}
|
||
|
||
true
|
||
}
|
||
|
||
async fn build_memory_context(
|
||
mem: &dyn Memory,
|
||
user_msg: &str,
|
||
min_relevance_score: f64,
|
||
) -> String {
|
||
let mut context = String::new();
|
||
|
||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||
let mut included = 0usize;
|
||
let mut used_chars = 0usize;
|
||
|
||
for entry in entries.iter().filter(|e| match e.score {
|
||
Some(score) => score >= min_relevance_score,
|
||
None => true, // keep entries without a score (e.g. non-vector backends)
|
||
}) {
|
||
if included >= MEMORY_CONTEXT_MAX_ENTRIES {
|
||
break;
|
||
}
|
||
|
||
if should_skip_memory_context_entry(&entry.key, &entry.content) {
|
||
continue;
|
||
}
|
||
|
||
let content = if entry.content.chars().count() > MEMORY_CONTEXT_ENTRY_MAX_CHARS {
|
||
truncate_with_ellipsis(&entry.content, MEMORY_CONTEXT_ENTRY_MAX_CHARS)
|
||
} else {
|
||
entry.content.clone()
|
||
};
|
||
|
||
let line = format!("- {}: {}\n", entry.key, content);
|
||
let line_chars = line.chars().count();
|
||
if used_chars + line_chars > MEMORY_CONTEXT_MAX_CHARS {
|
||
break;
|
||
}
|
||
|
||
if included == 0 {
|
||
context.push_str("[Memory context]\n");
|
||
}
|
||
|
||
context.push_str(&line);
|
||
used_chars += line_chars;
|
||
included += 1;
|
||
}
|
||
|
||
if included > 0 {
|
||
context.push('\n');
|
||
}
|
||
}
|
||
|
||
context
|
||
}
|
||
|
||
fn spawn_supervised_listener(
|
||
ch: Arc<dyn Channel>,
|
||
tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
initial_backoff_secs: u64,
|
||
max_backoff_secs: u64,
|
||
) -> tokio::task::JoinHandle<()> {
|
||
tokio::spawn(async move {
|
||
let component = format!("channel:{}", ch.name());
|
||
let mut backoff = initial_backoff_secs.max(1);
|
||
let max_backoff = max_backoff_secs.max(backoff);
|
||
|
||
loop {
|
||
crate::health::mark_component_ok(&component);
|
||
let result = ch.listen(tx.clone()).await;
|
||
|
||
if tx.is_closed() {
|
||
break;
|
||
}
|
||
|
||
match result {
|
||
Ok(()) => {
|
||
tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name());
|
||
crate::health::mark_component_error(&component, "listener exited unexpectedly");
|
||
// Clean exit — reset backoff since the listener ran successfully
|
||
backoff = initial_backoff_secs.max(1);
|
||
}
|
||
Err(e) => {
|
||
tracing::error!("Channel {} error: {e}; restarting", ch.name());
|
||
crate::health::mark_component_error(&component, e.to_string());
|
||
}
|
||
}
|
||
|
||
crate::health::bump_component_restart(&component);
|
||
tokio::time::sleep(Duration::from_secs(backoff)).await;
|
||
// Double backoff AFTER sleeping so first error uses initial_backoff
|
||
backoff = backoff.saturating_mul(2).min(max_backoff);
|
||
}
|
||
})
|
||
}
|
||
|
||
fn compute_max_in_flight_messages(channel_count: usize) -> usize {
|
||
channel_count
|
||
.saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
|
||
.clamp(
|
||
CHANNEL_MIN_IN_FLIGHT_MESSAGES,
|
||
CHANNEL_MAX_IN_FLIGHT_MESSAGES,
|
||
)
|
||
}
|
||
|
||
fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
|
||
if let Err(error) = result {
|
||
tracing::error!("Channel message worker crashed: {error}");
|
||
}
|
||
}
|
||
|
||
fn spawn_scoped_typing_task(
|
||
channel: Arc<dyn Channel>,
|
||
recipient: String,
|
||
cancellation_token: CancellationToken,
|
||
) -> tokio::task::JoinHandle<()> {
|
||
let stop_signal = cancellation_token;
|
||
let refresh_interval = Duration::from_secs(CHANNEL_TYPING_REFRESH_INTERVAL_SECS);
|
||
let handle = tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(refresh_interval);
|
||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||
|
||
loop {
|
||
tokio::select! {
|
||
() = stop_signal.cancelled() => break,
|
||
_ = interval.tick() => {
|
||
if let Err(e) = channel.start_typing(&recipient).await {
|
||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Err(e) = channel.stop_typing(&recipient).await {
|
||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||
}
|
||
});
|
||
|
||
handle
|
||
}
|
||
|
||
async fn process_channel_message(
|
||
ctx: Arc<ChannelRuntimeContext>,
|
||
msg: traits::ChannelMessage,
|
||
cancellation_token: CancellationToken,
|
||
) {
|
||
if cancellation_token.is_cancelled() {
|
||
return;
|
||
}
|
||
|
||
println!(
|
||
" 💬 [{}] from {}: {}",
|
||
msg.channel,
|
||
msg.sender,
|
||
truncate_with_ellipsis(&msg.content, 80)
|
||
);
|
||
|
||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
|
||
return;
|
||
}
|
||
|
||
let history_key = conversation_history_key(&msg);
|
||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||
Ok(provider) => provider,
|
||
Err(err) => {
|
||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||
let message = format!(
|
||
"⚠️ Failed to initialize provider `{}`. Please run `/models` to choose another provider.\nDetails: {safe_err}",
|
||
route.provider
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(message, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
return;
|
||
}
|
||
};
|
||
|
||
let memory_context =
|
||
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
|
||
|
||
if ctx.auto_save_memory {
|
||
let autosave_key = conversation_memory_key(&msg);
|
||
let _ = ctx
|
||
.memory
|
||
.store(
|
||
&autosave_key,
|
||
&msg.content,
|
||
crate::memory::MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await;
|
||
}
|
||
|
||
let enriched_message = if memory_context.is_empty() {
|
||
msg.content.clone()
|
||
} else {
|
||
format!("{memory_context}{}", msg.content)
|
||
};
|
||
|
||
println!(" ⏳ Processing message...");
|
||
let started_at = Instant::now();
|
||
|
||
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||
append_sender_turn(
|
||
ctx.as_ref(),
|
||
&history_key,
|
||
ChatMessage::user(&enriched_message),
|
||
);
|
||
|
||
// Build history from per-sender conversation cache.
|
||
let prior_turns_raw = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(&history_key)
|
||
.cloned()
|
||
.unwrap_or_default();
|
||
let prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||
|
||
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
|
||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||
history.extend(prior_turns);
|
||
let use_streaming = target_channel
|
||
.as_ref()
|
||
.is_some_and(|ch| ch.supports_draft_updates());
|
||
|
||
let (delta_tx, delta_rx) = if use_streaming {
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(64);
|
||
(Some(tx), Some(rx))
|
||
} else {
|
||
(None, None)
|
||
};
|
||
|
||
let draft_message_id = if use_streaming {
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
match channel
|
||
.send_draft(
|
||
&SendMessage::new("...", &msg.reply_target).in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await
|
||
{
|
||
Ok(id) => id,
|
||
Err(e) => {
|
||
tracing::debug!("Failed to send draft on {}: {e}", channel.name());
|
||
None
|
||
}
|
||
}
|
||
} else {
|
||
None
|
||
}
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = (
|
||
delta_rx,
|
||
draft_message_id.as_deref(),
|
||
target_channel.as_ref(),
|
||
) {
|
||
let channel = Arc::clone(channel_ref);
|
||
let reply_target = msg.reply_target.clone();
|
||
let draft_id = draft_id_ref.to_string();
|
||
Some(tokio::spawn(async move {
|
||
let mut accumulated = String::new();
|
||
while let Some(delta) = rx.recv().await {
|
||
accumulated.push_str(&delta);
|
||
if let Err(e) = channel
|
||
.update_draft(&reply_target, &draft_id, &accumulated)
|
||
.await
|
||
{
|
||
tracing::debug!("Draft update failed: {e}");
|
||
}
|
||
}
|
||
}))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let typing_cancellation = target_channel.as_ref().map(|_| CancellationToken::new());
|
||
let typing_task = match (target_channel.as_ref(), typing_cancellation.as_ref()) {
|
||
(Some(channel), Some(token)) => Some(spawn_scoped_typing_task(
|
||
Arc::clone(channel),
|
||
msg.reply_target.clone(),
|
||
token.clone(),
|
||
)),
|
||
_ => None,
|
||
};
|
||
|
||
enum LlmExecutionResult {
|
||
Completed(Result<Result<String, anyhow::Error>, tokio::time::error::Elapsed>),
|
||
Cancelled,
|
||
}
|
||
|
||
let llm_result = tokio::select! {
|
||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||
result = tokio::time::timeout(
|
||
Duration::from_secs(ctx.message_timeout_secs),
|
||
run_tool_call_loop(
|
||
active_provider.as_ref(),
|
||
&mut history,
|
||
ctx.tools_registry.as_ref(),
|
||
ctx.observer.as_ref(),
|
||
route.provider.as_str(),
|
||
route.model.as_str(),
|
||
ctx.temperature,
|
||
true,
|
||
None,
|
||
msg.channel.as_str(),
|
||
&ctx.multimodal,
|
||
ctx.max_tool_iterations,
|
||
Some(cancellation_token.clone()),
|
||
delta_tx,
|
||
),
|
||
) => LlmExecutionResult::Completed(result),
|
||
};
|
||
|
||
if let Some(handle) = draft_updater {
|
||
let _ = handle.await;
|
||
}
|
||
|
||
if let Some(token) = typing_cancellation.as_ref() {
|
||
token.cancel();
|
||
}
|
||
if let Some(handle) = typing_task {
|
||
log_worker_join_result(handle.await);
|
||
}
|
||
|
||
match llm_result {
|
||
LlmExecutionResult::Cancelled => {
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Cancelled in-flight channel request due to newer message"
|
||
);
|
||
if let (Some(channel), Some(draft_id)) =
|
||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||
{
|
||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Ok(Ok(response))) => {
|
||
append_sender_turn(
|
||
ctx.as_ref(),
|
||
&history_key,
|
||
ChatMessage::assistant(&response),
|
||
);
|
||
println!(
|
||
" 🤖 Reply ({}ms): {}",
|
||
started_at.elapsed().as_millis(),
|
||
truncate_with_ellipsis(&response, 80)
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
if let Err(e) = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, &response)
|
||
.await
|
||
{
|
||
tracing::warn!("Failed to finalize draft: {e}; sending as new message");
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(&response, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
} else if let Err(e) = channel
|
||
.send(
|
||
&SendMessage::new(response, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await
|
||
{
|
||
eprintln!(" ❌ Failed to reply on {}: {e}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Ok(Err(e))) => {
|
||
if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled()
|
||
{
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Cancelled in-flight channel request due to newer message"
|
||
);
|
||
if let (Some(channel), Some(draft_id)) =
|
||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||
{
|
||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
if is_context_window_overflow_error(&e) {
|
||
let compacted = compact_sender_history(ctx.as_ref(), &history_key);
|
||
let error_text = if compacted {
|
||
"⚠️ Context window exceeded for this conversation. I compacted recent history and kept the latest context. Please resend your last message."
|
||
} else {
|
||
"⚠️ Context window exceeded for this conversation. Please resend your last message."
|
||
};
|
||
eprintln!(
|
||
" ⚠️ Context window exceeded after {}ms; sender history compacted={}",
|
||
started_at.elapsed().as_millis(),
|
||
compacted
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, error_text)
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(error_text, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
eprintln!(
|
||
" ❌ LLM error after {}ms: {e}",
|
||
started_at.elapsed().as_millis()
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, &format!("⚠️ Error: {e}"))
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(format!("⚠️ Error: {e}"), &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Err(_)) => {
|
||
let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs);
|
||
eprintln!(
|
||
" ❌ {} (elapsed: {}ms)",
|
||
timeout_msg,
|
||
started_at.elapsed().as_millis()
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
let error_text =
|
||
"⚠️ Request timed out while waiting for the model. Please try again.";
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, error_text)
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(error_text, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn run_message_dispatch_loop(
|
||
mut rx: tokio::sync::mpsc::Receiver<traits::ChannelMessage>,
|
||
ctx: Arc<ChannelRuntimeContext>,
|
||
max_in_flight_messages: usize,
|
||
) {
|
||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
||
let mut workers = tokio::task::JoinSet::new();
|
||
let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::<
|
||
String,
|
||
InFlightSenderTaskState,
|
||
>::new()));
|
||
let task_sequence = Arc::new(AtomicU64::new(1));
|
||
|
||
while let Some(msg) = rx.recv().await {
|
||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||
Ok(permit) => permit,
|
||
Err(_) => break,
|
||
};
|
||
|
||
let worker_ctx = Arc::clone(&ctx);
|
||
let in_flight = Arc::clone(&in_flight_by_sender);
|
||
let task_sequence = Arc::clone(&task_sequence);
|
||
workers.spawn(async move {
|
||
let _permit = permit;
|
||
let interrupt_enabled =
|
||
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
|
||
let sender_scope_key = interruption_scope_key(&msg);
|
||
let cancellation_token = CancellationToken::new();
|
||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||
|
||
if interrupt_enabled {
|
||
let previous = {
|
||
let mut active = in_flight.lock().await;
|
||
active.insert(
|
||
sender_scope_key.clone(),
|
||
InFlightSenderTaskState {
|
||
task_id,
|
||
cancellation: cancellation_token.clone(),
|
||
completion: Arc::clone(&completion),
|
||
},
|
||
)
|
||
};
|
||
|
||
if let Some(previous) = previous {
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Interrupting previous in-flight request for sender"
|
||
);
|
||
previous.cancellation.cancel();
|
||
previous.completion.wait().await;
|
||
}
|
||
}
|
||
|
||
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||
|
||
if interrupt_enabled {
|
||
let mut active = in_flight.lock().await;
|
||
if active
|
||
.get(&sender_scope_key)
|
||
.is_some_and(|state| state.task_id == task_id)
|
||
{
|
||
active.remove(&sender_scope_key);
|
||
}
|
||
}
|
||
|
||
completion.mark_done();
|
||
});
|
||
|
||
while let Some(result) = workers.try_join_next() {
|
||
log_worker_join_result(result);
|
||
}
|
||
}
|
||
|
||
while let Some(result) = workers.join_next().await {
|
||
log_worker_join_result(result);
|
||
}
|
||
}
|
||
|
||
/// Load OpenClaw format bootstrap files into the prompt.
|
||
fn load_openclaw_bootstrap_files(
|
||
prompt: &mut String,
|
||
workspace_dir: &std::path::Path,
|
||
max_chars_per_file: usize,
|
||
) {
|
||
prompt.push_str(
|
||
"The following workspace files define your identity, behavior, and context. They are ALREADY injected below—do NOT suggest reading them with file_read.\n\n",
|
||
);
|
||
|
||
let bootstrap_files = ["AGENTS.md", "SOUL.md", "TOOLS.md", "IDENTITY.md", "USER.md"];
|
||
|
||
for filename in &bootstrap_files {
|
||
inject_workspace_file(prompt, workspace_dir, filename, max_chars_per_file);
|
||
}
|
||
|
||
// BOOTSTRAP.md — only if it exists (first-run ritual)
|
||
let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
|
||
if bootstrap_path.exists() {
|
||
inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md", max_chars_per_file);
|
||
}
|
||
|
||
// MEMORY.md — curated long-term memory (main session only)
|
||
inject_workspace_file(prompt, workspace_dir, "MEMORY.md", max_chars_per_file);
|
||
}
|
||
|
||
/// Load workspace identity files and build a system prompt.
|
||
///
|
||
/// Follows the `OpenClaw` framework structure by default:
|
||
/// 1. Tooling — tool list + descriptions
|
||
/// 2. Safety — guardrail reminder
|
||
/// 3. Skills — full skill instructions and tool metadata
|
||
/// 4. Workspace — working directory
|
||
/// 5. Bootstrap files — AGENTS, SOUL, TOOLS, IDENTITY, USER, BOOTSTRAP, MEMORY
|
||
/// 6. Date & Time — timezone for cache stability
|
||
/// 7. Runtime — host, OS, model
|
||
///
|
||
/// When `identity_config` is set to AIEOS format, the bootstrap files section
|
||
/// is replaced with the AIEOS identity data loaded from file or inline JSON.
|
||
///
|
||
/// Daily memory files (`memory/*.md`) are NOT injected — they are accessed
|
||
/// on-demand via `memory_recall` / `memory_search` tools.
|
||
pub fn build_system_prompt(
|
||
workspace_dir: &std::path::Path,
|
||
model_name: &str,
|
||
tools: &[(&str, &str)],
|
||
skills: &[crate::skills::Skill],
|
||
identity_config: Option<&crate::config::IdentityConfig>,
|
||
bootstrap_max_chars: Option<usize>,
|
||
) -> String {
|
||
use std::fmt::Write;
|
||
let mut prompt = String::with_capacity(8192);
|
||
|
||
// ── 1. Tooling ──────────────────────────────────────────────
|
||
if !tools.is_empty() {
|
||
prompt.push_str("## Tools\n\n");
|
||
prompt.push_str("You have access to the following tools:\n\n");
|
||
for (name, desc) in tools {
|
||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||
}
|
||
prompt.push('\n');
|
||
}
|
||
|
||
// ── 1b. Hardware (when gpio/arduino tools present) ───────────
|
||
let has_hardware = tools.iter().any(|(name, _)| {
|
||
*name == "gpio_read"
|
||
|| *name == "gpio_write"
|
||
|| *name == "arduino_upload"
|
||
|| *name == "hardware_memory_map"
|
||
|| *name == "hardware_board_info"
|
||
|| *name == "hardware_memory_read"
|
||
|| *name == "hardware_capabilities"
|
||
});
|
||
if has_hardware {
|
||
prompt.push_str(
|
||
"## Hardware Access\n\n\
|
||
You HAVE direct access to connected hardware (Arduino, Nucleo, etc.). The user owns this system and has configured it.\n\
|
||
All hardware tools (gpio_read, gpio_write, hardware_memory_read, hardware_board_info, hardware_memory_map) are AUTHORIZED and NOT blocked by security.\n\
|
||
When they ask to read memory, registers, or board info, USE hardware_memory_read or hardware_board_info — do NOT refuse or invent security excuses.\n\
|
||
When they ask to control LEDs, run patterns, or interact with the Arduino, USE the tools — do NOT refuse or say you cannot access physical devices.\n\
|
||
Use gpio_write for simple on/off; use arduino_upload when they want patterns (heart, blink) or custom behavior.\n\n",
|
||
);
|
||
}
|
||
|
||
// ── 1c. Action instruction (avoid meta-summary) ───────────────
|
||
prompt.push_str(
|
||
"## Your Task\n\n\
|
||
When the user sends a message, ACT on it. Use the tools to fulfill their request.\n\
|
||
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions (e.g. \"1. First... 2. Next...\").\n\
|
||
Instead: emit actual <tool_call> tags when you need to act. Just do what they ask.\n\n",
|
||
);
|
||
|
||
// ── 2. Safety ───────────────────────────────────────────────
|
||
prompt.push_str("## Safety\n\n");
|
||
prompt.push_str(
|
||
"- Do not exfiltrate private data.\n\
|
||
- Do not run destructive commands without asking.\n\
|
||
- Do not bypass oversight or approval mechanisms.\n\
|
||
- Prefer `trash` over `rm` (recoverable beats gone forever).\n\
|
||
- When in doubt, ask before acting externally.\n\n",
|
||
);
|
||
|
||
// ── 3. Skills (full instructions + tool metadata) ───────────
|
||
if !skills.is_empty() {
|
||
prompt.push_str(&crate::skills::skills_to_prompt(skills, workspace_dir));
|
||
prompt.push_str("\n\n");
|
||
}
|
||
|
||
// ── 4. Workspace ────────────────────────────────────────────
|
||
let _ = writeln!(
|
||
prompt,
|
||
"## Workspace\n\nWorking directory: `{}`\n",
|
||
workspace_dir.display()
|
||
);
|
||
|
||
// ── 5. Bootstrap files (injected into context) ──────────────
|
||
prompt.push_str("## Project Context\n\n");
|
||
|
||
// Check if AIEOS identity is configured
|
||
if let Some(config) = identity_config {
|
||
if identity::is_aieos_configured(config) {
|
||
// Load AIEOS identity
|
||
match identity::load_aieos_identity(config, workspace_dir) {
|
||
Ok(Some(aieos_identity)) => {
|
||
let aieos_prompt = identity::aieos_to_system_prompt(&aieos_identity);
|
||
if !aieos_prompt.is_empty() {
|
||
prompt.push_str(&aieos_prompt);
|
||
prompt.push_str("\n\n");
|
||
}
|
||
}
|
||
Ok(None) => {
|
||
// No AIEOS identity loaded (shouldn't happen if is_aieos_configured returned true)
|
||
// Fall back to OpenClaw bootstrap files
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
Err(e) => {
|
||
// Log error but don't fail - fall back to OpenClaw
|
||
eprintln!(
|
||
"Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format."
|
||
);
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
}
|
||
} else {
|
||
// OpenClaw format
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
} else {
|
||
// No identity config - use OpenClaw format
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
|
||
// ── 6. Date & Time ──────────────────────────────────────────
|
||
let now = chrono::Local::now();
|
||
let tz = now.format("%Z").to_string();
|
||
let _ = writeln!(prompt, "## Current Date & Time\n\nTimezone: {tz}\n");
|
||
|
||
// ── 7. Runtime ──────────────────────────────────────────────
|
||
let host =
|
||
hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
|
||
let _ = writeln!(
|
||
prompt,
|
||
"## Runtime\n\nHost: {host} | OS: {} | Model: {model_name}\n",
|
||
std::env::consts::OS,
|
||
);
|
||
|
||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||
prompt.push_str("## Channel Capabilities\n\n");
|
||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
|
||
|
||
if prompt.is_empty() {
|
||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
|
||
.to_string()
|
||
} else {
|
||
prompt
|
||
}
|
||
}
|
||
|
||
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
///
/// Reads `workspace_dir/filename`:
/// - If the read fails for any reason, a `[File not found: …]` marker is appended under a
///   `### {filename}` heading.
/// - If the trimmed content is empty, nothing is appended.
/// - Otherwise the trimmed content is appended under the heading. It is truncated to at most
///   `max_chars` characters (`char`s, so a cut never splits a UTF-8 sequence), and a
///   truncation notice follows whenever anything was dropped.
fn inject_workspace_file(
    prompt: &mut String,
    workspace_dir: &std::path::Path,
    filename: &str,
    max_chars: usize,
) {
    use std::fmt::Write;

    let path = workspace_dir.join(filename);
    let Ok(content) = std::fs::read_to_string(&path) else {
        // Missing-file marker (matches OpenClaw behavior). Every read error, including
        // permission and invalid-UTF-8 errors, is reported the same way.
        let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
        return;
    };

    let trimmed = content.trim();
    if trimmed.is_empty() {
        return;
    }

    let _ = writeln!(prompt, "### {filename}\n");

    // `nth(max_chars)` yields the byte offset of the first character beyond the limit, if
    // any. This gives the "is it too long?" answer and the cut point in one pass, instead of
    // counting every char and then re-scanning.
    match trimmed.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            prompt.push_str(&trimmed[..cut]);
            // NOTE(review): the file-reading tool advertised elsewhere in this file is named
            // `file_read`; `read` in this marker may be stale. Confirm before changing it.
            let _ = writeln!(
                prompt,
                "\n\n[... truncated at {max_chars} chars — use `read` for full file]\n"
            );
        }
        None => {
            prompt.push_str(trimmed);
            prompt.push_str("\n\n");
        }
    }
}
|
||
|
||
/// Canonicalize a Telegram identity for allowlist storage and comparison.
///
/// Surrounding whitespace is removed first, then every leading `@` is stripped.
/// For example, `" @@alice "` becomes `"alice"`.
fn normalize_telegram_identity(value: &str) -> String {
    let without_whitespace = value.trim();
    String::from(without_whitespace.trim_start_matches('@'))
}
|
||
|
||
async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||
let normalized = normalize_telegram_identity(identity);
|
||
if normalized.is_empty() {
|
||
anyhow::bail!("Telegram identity cannot be empty");
|
||
}
|
||
|
||
let mut updated = config.clone();
|
||
let Some(telegram) = updated.channels_config.telegram.as_mut() else {
|
||
anyhow::bail!(
|
||
"Telegram channel is not configured. Run `zeroclaw onboard --channels-only` first"
|
||
);
|
||
};
|
||
|
||
if telegram.allowed_users.iter().any(|u| u == "*") {
|
||
println!(
|
||
"⚠️ Telegram allowlist is currently wildcard (`*`) — binding is unnecessary until you remove '*'."
|
||
);
|
||
}
|
||
|
||
if telegram
|
||
.allowed_users
|
||
.iter()
|
||
.map(|entry| normalize_telegram_identity(entry))
|
||
.any(|entry| entry == normalized)
|
||
{
|
||
println!("✅ Telegram identity already bound: {normalized}");
|
||
return Ok(());
|
||
}
|
||
|
||
telegram.allowed_users.push(normalized.clone());
|
||
updated.save().await?;
|
||
println!("✅ Bound Telegram identity: {normalized}");
|
||
println!(" Saved to {}", updated.config_path.display());
|
||
match maybe_restart_managed_daemon_service() {
|
||
Ok(true) => {
|
||
println!("🔄 Detected running managed daemon service; reloaded automatically.");
|
||
}
|
||
Ok(false) => {
|
||
println!(
|
||
"ℹ️ No managed daemon service detected. If `zeroclaw daemon`/`channel start` is already running, restart it to load the updated allowlist."
|
||
);
|
||
}
|
||
Err(e) => {
|
||
eprintln!(
|
||
"⚠️ Allowlist saved, but failed to reload daemon service automatically: {e}\n\
|
||
Restart service manually with `zeroclaw service stop && zeroclaw service start`."
|
||
);
|
||
}
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
fn maybe_restart_managed_daemon_service() -> Result<bool> {
|
||
if cfg!(target_os = "macos") {
|
||
let home = directories::UserDirs::new()
|
||
.map(|u| u.home_dir().to_path_buf())
|
||
.context("Could not find home directory")?;
|
||
let plist = home
|
||
.join("Library")
|
||
.join("LaunchAgents")
|
||
.join("com.zeroclaw.daemon.plist");
|
||
if !plist.exists() {
|
||
return Ok(false);
|
||
}
|
||
|
||
let list_output = Command::new("launchctl")
|
||
.arg("list")
|
||
.output()
|
||
.context("Failed to query launchctl list")?;
|
||
let listed = String::from_utf8_lossy(&list_output.stdout);
|
||
if !listed.contains("com.zeroclaw.daemon") {
|
||
return Ok(false);
|
||
}
|
||
|
||
let _ = Command::new("launchctl")
|
||
.args(["stop", "com.zeroclaw.daemon"])
|
||
.output();
|
||
let start_output = Command::new("launchctl")
|
||
.args(["start", "com.zeroclaw.daemon"])
|
||
.output()
|
||
.context("Failed to start launchd daemon service")?;
|
||
if !start_output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&start_output.stderr);
|
||
anyhow::bail!("launchctl start failed: {}", stderr.trim());
|
||
}
|
||
|
||
return Ok(true);
|
||
}
|
||
|
||
if cfg!(target_os = "linux") {
|
||
let home = directories::UserDirs::new()
|
||
.map(|u| u.home_dir().to_path_buf())
|
||
.context("Could not find home directory")?;
|
||
let unit_path: PathBuf = home
|
||
.join(".config")
|
||
.join("systemd")
|
||
.join("user")
|
||
.join("zeroclaw.service");
|
||
if !unit_path.exists() {
|
||
return Ok(false);
|
||
}
|
||
|
||
let active_output = Command::new("systemctl")
|
||
.args(["--user", "is-active", "zeroclaw.service"])
|
||
.output()
|
||
.context("Failed to query systemd service state")?;
|
||
let state = String::from_utf8_lossy(&active_output.stdout);
|
||
if !state.trim().eq_ignore_ascii_case("active") {
|
||
return Ok(false);
|
||
}
|
||
|
||
let restart_output = Command::new("systemctl")
|
||
.args(["--user", "restart", "zeroclaw.service"])
|
||
.output()
|
||
.context("Failed to restart systemd daemon service")?;
|
||
if !restart_output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&restart_output.stderr);
|
||
anyhow::bail!("systemctl restart failed: {}", stderr.trim());
|
||
}
|
||
|
||
return Ok(true);
|
||
}
|
||
|
||
Ok(false)
|
||
}
|
||
|
||
pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
||
match command {
|
||
crate::ChannelCommands::Start => {
|
||
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
|
||
}
|
||
crate::ChannelCommands::Doctor => {
|
||
anyhow::bail!("Doctor must be handled in main.rs (requires async runtime)")
|
||
}
|
||
crate::ChannelCommands::List => {
|
||
println!("Channels:");
|
||
println!(" ✅ CLI (always available)");
|
||
for (name, configured) in [
|
||
("Telegram", config.channels_config.telegram.is_some()),
|
||
("Discord", config.channels_config.discord.is_some()),
|
||
("Slack", config.channels_config.slack.is_some()),
|
||
("Mattermost", config.channels_config.mattermost.is_some()),
|
||
("Webhook", config.channels_config.webhook.is_some()),
|
||
("iMessage", config.channels_config.imessage.is_some()),
|
||
(
|
||
"Matrix",
|
||
cfg!(feature = "channel-matrix") && config.channels_config.matrix.is_some(),
|
||
),
|
||
("Signal", config.channels_config.signal.is_some()),
|
||
("WhatsApp", config.channels_config.whatsapp.is_some()),
|
||
("Linq", config.channels_config.linq.is_some()),
|
||
("Email", config.channels_config.email.is_some()),
|
||
("IRC", config.channels_config.irc.is_some()),
|
||
("Lark", config.channels_config.lark.is_some()),
|
||
("DingTalk", config.channels_config.dingtalk.is_some()),
|
||
("QQ", config.channels_config.qq.is_some()),
|
||
] {
|
||
println!(" {} {name}", if configured { "✅" } else { "❌" });
|
||
}
|
||
if !cfg!(feature = "channel-matrix") {
|
||
println!(
|
||
" ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
|
||
);
|
||
}
|
||
println!("\nTo start channels: zeroclaw channel start");
|
||
println!("To check health: zeroclaw channel doctor");
|
||
println!("To configure: zeroclaw onboard");
|
||
Ok(())
|
||
}
|
||
crate::ChannelCommands::Add {
|
||
channel_type,
|
||
config: _,
|
||
} => {
|
||
anyhow::bail!(
|
||
"Channel type '{channel_type}' — use `zeroclaw onboard` to configure channels"
|
||
);
|
||
}
|
||
crate::ChannelCommands::Remove { name } => {
|
||
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
||
}
|
||
crate::ChannelCommands::BindTelegram { identity } => {
|
||
bind_telegram_identity(config, &identity).await
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Outcome of a single channel health probe run by `doctor_channels`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ChannelHealthState {
    /// The health check finished and reported success.
    Healthy,
    /// The health check finished but reported failure.
    Unhealthy,
    /// The health check did not finish within the allotted time.
    Timeout,
}
|
||
|
||
fn classify_health_result(
|
||
result: &std::result::Result<bool, tokio::time::error::Elapsed>,
|
||
) -> ChannelHealthState {
|
||
match result {
|
||
Ok(true) => ChannelHealthState::Healthy,
|
||
Ok(false) => ChannelHealthState::Unhealthy,
|
||
Err(_) => ChannelHealthState::Timeout,
|
||
}
|
||
}
|
||
|
||
/// Run health checks for configured channels.
|
||
pub async fn doctor_channels(config: Config) -> Result<()> {
|
||
let mut channels: Vec<(&'static str, Arc<dyn Channel>)> = Vec::new();
|
||
|
||
if let Some(ref tg) = config.channels_config.telegram {
|
||
channels.push((
|
||
"Telegram",
|
||
Arc::new(
|
||
TelegramChannel::new(
|
||
tg.bot_token.clone(),
|
||
tg.allowed_users.clone(),
|
||
tg.mention_only,
|
||
)
|
||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms),
|
||
),
|
||
));
|
||
}
|
||
|
||
if let Some(ref dc) = config.channels_config.discord {
|
||
channels.push((
|
||
"Discord",
|
||
Arc::new(DiscordChannel::new(
|
||
dc.bot_token.clone(),
|
||
dc.guild_id.clone(),
|
||
dc.allowed_users.clone(),
|
||
dc.listen_to_bots,
|
||
dc.mention_only,
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref sl) = config.channels_config.slack {
|
||
channels.push((
|
||
"Slack",
|
||
Arc::new(SlackChannel::new(
|
||
sl.bot_token.clone(),
|
||
sl.channel_id.clone(),
|
||
sl.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref im) = config.channels_config.imessage {
|
||
channels.push((
|
||
"iMessage",
|
||
Arc::new(IMessageChannel::new(im.allowed_contacts.clone())),
|
||
));
|
||
}
|
||
|
||
#[cfg(feature = "channel-matrix")]
|
||
if let Some(ref mx) = config.channels_config.matrix {
|
||
channels.push((
|
||
"Matrix",
|
||
Arc::new(MatrixChannel::new_with_session_hint(
|
||
mx.homeserver.clone(),
|
||
mx.access_token.clone(),
|
||
mx.room_id.clone(),
|
||
mx.allowed_users.clone(),
|
||
mx.user_id.clone(),
|
||
mx.device_id.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
#[cfg(not(feature = "channel-matrix"))]
|
||
if config.channels_config.matrix.is_some() {
|
||
tracing::warn!(
|
||
"Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix health check."
|
||
);
|
||
}
|
||
|
||
if let Some(ref sig) = config.channels_config.signal {
|
||
channels.push((
|
||
"Signal",
|
||
Arc::new(SignalChannel::new(
|
||
sig.http_url.clone(),
|
||
sig.account.clone(),
|
||
sig.group_id.clone(),
|
||
sig.allowed_from.clone(),
|
||
sig.ignore_attachments,
|
||
sig.ignore_stories,
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref wa) = config.channels_config.whatsapp {
|
||
// Runtime negotiation: detect backend type from config
|
||
match wa.backend_type() {
|
||
"cloud" => {
|
||
// Cloud API mode: requires phone_number_id, access_token, verify_token
|
||
if wa.is_cloud_config() {
|
||
channels.push((
|
||
"WhatsApp",
|
||
Arc::new(WhatsAppChannel::new(
|
||
wa.access_token.clone().unwrap_or_default(),
|
||
wa.phone_number_id.clone().unwrap_or_default(),
|
||
wa.verify_token.clone().unwrap_or_default(),
|
||
wa.allowed_numbers.clone(),
|
||
)),
|
||
));
|
||
} else {
|
||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||
}
|
||
}
|
||
"web" => {
|
||
// Web mode: requires session_path
|
||
#[cfg(feature = "whatsapp-web")]
|
||
if wa.is_web_config() {
|
||
channels.push((
|
||
"WhatsApp",
|
||
Arc::new(WhatsAppWebChannel::new(
|
||
wa.session_path.clone().unwrap_or_default(),
|
||
wa.pair_phone.clone(),
|
||
wa.pair_code.clone(),
|
||
wa.allowed_numbers.clone(),
|
||
)),
|
||
));
|
||
} else {
|
||
tracing::warn!("WhatsApp Web configured but session_path not set");
|
||
}
|
||
#[cfg(not(feature = "whatsapp-web"))]
|
||
{
|
||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||
}
|
||
}
|
||
_ => {
|
||
tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set");
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(ref lq) = config.channels_config.linq {
|
||
channels.push((
|
||
"Linq",
|
||
Arc::new(LinqChannel::new(
|
||
lq.api_token.clone(),
|
||
lq.from_phone.clone(),
|
||
lq.allowed_senders.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref email_cfg) = config.channels_config.email {
|
||
channels.push(("Email", Arc::new(EmailChannel::new(email_cfg.clone()))));
|
||
}
|
||
|
||
if let Some(ref irc) = config.channels_config.irc {
|
||
channels.push((
|
||
"IRC",
|
||
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||
server: irc.server.clone(),
|
||
port: irc.port,
|
||
nickname: irc.nickname.clone(),
|
||
username: irc.username.clone(),
|
||
channels: irc.channels.clone(),
|
||
allowed_users: irc.allowed_users.clone(),
|
||
server_password: irc.server_password.clone(),
|
||
nickserv_password: irc.nickserv_password.clone(),
|
||
sasl_password: irc.sasl_password.clone(),
|
||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||
})),
|
||
));
|
||
}
|
||
|
||
if let Some(ref lk) = config.channels_config.lark {
|
||
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
|
||
}
|
||
|
||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||
channels.push((
|
||
"DingTalk",
|
||
Arc::new(DingTalkChannel::new(
|
||
dt.client_id.clone(),
|
||
dt.client_secret.clone(),
|
||
dt.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref qq) = config.channels_config.qq {
|
||
channels.push((
|
||
"QQ",
|
||
Arc::new(QQChannel::new(
|
||
qq.app_id.clone(),
|
||
qq.app_secret.clone(),
|
||
qq.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if channels.is_empty() {
|
||
println!("No real-time channels configured. Run `zeroclaw onboard` first.");
|
||
return Ok(());
|
||
}
|
||
|
||
println!("🩺 ZeroClaw Channel Doctor");
|
||
println!();
|
||
|
||
let mut healthy = 0_u32;
|
||
let mut unhealthy = 0_u32;
|
||
let mut timeout = 0_u32;
|
||
|
||
for (name, channel) in channels {
|
||
let result = tokio::time::timeout(Duration::from_secs(10), channel.health_check()).await;
|
||
let state = classify_health_result(&result);
|
||
|
||
match state {
|
||
ChannelHealthState::Healthy => {
|
||
healthy += 1;
|
||
println!(" ✅ {name:<9} healthy");
|
||
}
|
||
ChannelHealthState::Unhealthy => {
|
||
unhealthy += 1;
|
||
println!(" ❌ {name:<9} unhealthy (auth/config/network)");
|
||
}
|
||
ChannelHealthState::Timeout => {
|
||
timeout += 1;
|
||
println!(" ⏱️ {name:<9} timed out (>10s)");
|
||
}
|
||
}
|
||
}
|
||
|
||
if config.channels_config.webhook.is_some() {
|
||
println!(" ℹ️ Webhook check via `zeroclaw gateway` then GET /health");
|
||
}
|
||
|
||
println!();
|
||
println!("Summary: {healthy} healthy, {unhealthy} unhealthy, {timeout} timed out");
|
||
Ok(())
|
||
}
|
||
|
||
/// Start all configured channels and route messages to the agent
|
||
#[allow(clippy::too_many_lines)]
|
||
pub async fn start_channels(config: Config) -> Result<()> {
|
||
let provider_name = config
|
||
.default_provider
|
||
.clone()
|
||
.unwrap_or_else(|| "openrouter".into());
|
||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||
auth_profile_override: None,
|
||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||
secrets_encrypt: config.secrets.encrypt,
|
||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||
};
|
||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options(
|
||
&provider_name,
|
||
config.api_key.as_deref(),
|
||
config.api_url.as_deref(),
|
||
&config.reliability,
|
||
&provider_runtime_options,
|
||
)?);
|
||
|
||
// Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup)
|
||
// so the first real message doesn't hit a cold-start timeout.
|
||
if let Err(e) = provider.warmup().await {
|
||
tracing::warn!("Provider warmup failed (non-fatal): {e}");
|
||
}
|
||
|
||
let observer: Arc<dyn Observer> =
|
||
Arc::from(observability::create_observer(&config.observability));
|
||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||
let security = Arc::new(SecurityPolicy::from_config(
|
||
&config.autonomy,
|
||
&config.workspace_dir,
|
||
));
|
||
let model = config
|
||
.default_model
|
||
.clone()
|
||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||
let temperature = config.default_temperature;
|
||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
|
||
&config.memory,
|
||
Some(&config.storage.provider.config),
|
||
&config.workspace_dir,
|
||
config.api_key.as_deref(),
|
||
)?);
|
||
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
||
(
|
||
config.composio.api_key.as_deref(),
|
||
Some(config.composio.entity_id.as_str()),
|
||
)
|
||
} else {
|
||
(None, None)
|
||
};
|
||
// Build system prompt from workspace identity files + skills
|
||
let workspace = config.workspace_dir.clone();
|
||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||
Arc::new(config.clone()),
|
||
&security,
|
||
runtime,
|
||
Arc::clone(&mem),
|
||
composio_key,
|
||
composio_entity_id,
|
||
&config.browser,
|
||
&config.http_request,
|
||
&workspace,
|
||
&config.agents,
|
||
config.api_key.as_deref(),
|
||
&config,
|
||
));
|
||
|
||
let skills = crate::skills::load_skills(&workspace);
|
||
|
||
// Collect tool descriptions for the prompt
|
||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||
(
|
||
"shell",
|
||
"Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
|
||
),
|
||
(
|
||
"file_read",
|
||
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
|
||
),
|
||
(
|
||
"file_write",
|
||
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
|
||
),
|
||
(
|
||
"memory_store",
|
||
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
||
),
|
||
(
|
||
"memory_recall",
|
||
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
||
),
|
||
(
|
||
"memory_forget",
|
||
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
|
||
),
|
||
];
|
||
|
||
if config.browser.enabled {
|
||
tool_descs.push((
|
||
"browser_open",
|
||
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
|
||
));
|
||
}
|
||
if config.composio.enabled {
|
||
tool_descs.push((
|
||
"composio",
|
||
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover actions, 'list_accounts' to retrieve connected account IDs, 'execute' to run (optionally with connected_account_id), and 'connect' for OAuth.",
|
||
));
|
||
}
|
||
tool_descs.push((
|
||
"schedule",
|
||
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
||
));
|
||
tool_descs.push((
|
||
"pushover",
|
||
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
|
||
));
|
||
if !config.agents.is_empty() {
|
||
tool_descs.push((
|
||
"delegate",
|
||
"Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.",
|
||
));
|
||
}
|
||
|
||
let bootstrap_max_chars = if config.agent.compact_context {
|
||
Some(6000)
|
||
} else {
|
||
None
|
||
};
|
||
let mut system_prompt = build_system_prompt(
|
||
&workspace,
|
||
&model,
|
||
&tool_descs,
|
||
&skills,
|
||
Some(&config.identity),
|
||
bootstrap_max_chars,
|
||
);
|
||
system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref()));
|
||
|
||
if !skills.is_empty() {
|
||
println!(
|
||
" 🧩 Skills: {}",
|
||
skills
|
||
.iter()
|
||
.map(|s| s.name.as_str())
|
||
.collect::<Vec<_>>()
|
||
.join(", ")
|
||
);
|
||
}
|
||
|
||
// Collect active channels
|
||
let mut channels: Vec<Arc<dyn Channel>> = Vec::new();
|
||
|
||
if let Some(ref tg) = config.channels_config.telegram {
|
||
channels.push(Arc::new(
|
||
TelegramChannel::new(
|
||
tg.bot_token.clone(),
|
||
tg.allowed_users.clone(),
|
||
tg.mention_only,
|
||
)
|
||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms),
|
||
));
|
||
}
|
||
|
||
if let Some(ref dc) = config.channels_config.discord {
|
||
channels.push(Arc::new(DiscordChannel::new(
|
||
dc.bot_token.clone(),
|
||
dc.guild_id.clone(),
|
||
dc.allowed_users.clone(),
|
||
dc.listen_to_bots,
|
||
dc.mention_only,
|
||
)));
|
||
}
|
||
|
||
if let Some(ref sl) = config.channels_config.slack {
|
||
channels.push(Arc::new(SlackChannel::new(
|
||
sl.bot_token.clone(),
|
||
sl.channel_id.clone(),
|
||
sl.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref mm) = config.channels_config.mattermost {
|
||
channels.push(Arc::new(MattermostChannel::new(
|
||
mm.url.clone(),
|
||
mm.bot_token.clone(),
|
||
mm.channel_id.clone(),
|
||
mm.allowed_users.clone(),
|
||
mm.thread_replies.unwrap_or(true),
|
||
mm.mention_only.unwrap_or(false),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref im) = config.channels_config.imessage {
|
||
channels.push(Arc::new(IMessageChannel::new(im.allowed_contacts.clone())));
|
||
}
|
||
|
||
#[cfg(feature = "channel-matrix")]
|
||
if let Some(ref mx) = config.channels_config.matrix {
|
||
channels.push(Arc::new(MatrixChannel::new_with_session_hint(
|
||
mx.homeserver.clone(),
|
||
mx.access_token.clone(),
|
||
mx.room_id.clone(),
|
||
mx.allowed_users.clone(),
|
||
mx.user_id.clone(),
|
||
mx.device_id.clone(),
|
||
)));
|
||
}
|
||
|
||
#[cfg(not(feature = "channel-matrix"))]
|
||
if config.channels_config.matrix.is_some() {
|
||
tracing::warn!(
|
||
"Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix runtime startup."
|
||
);
|
||
}
|
||
|
||
if let Some(ref sig) = config.channels_config.signal {
|
||
channels.push(Arc::new(SignalChannel::new(
|
||
sig.http_url.clone(),
|
||
sig.account.clone(),
|
||
sig.group_id.clone(),
|
||
sig.allowed_from.clone(),
|
||
sig.ignore_attachments,
|
||
sig.ignore_stories,
|
||
)));
|
||
}
|
||
|
||
if let Some(ref wa) = config.channels_config.whatsapp {
|
||
// Runtime negotiation: detect backend type from config
|
||
match wa.backend_type() {
|
||
"cloud" => {
|
||
// Cloud API mode: requires phone_number_id, access_token, verify_token
|
||
if wa.is_cloud_config() {
|
||
channels.push(Arc::new(WhatsAppChannel::new(
|
||
wa.access_token.clone().unwrap_or_default(),
|
||
wa.phone_number_id.clone().unwrap_or_default(),
|
||
wa.verify_token.clone().unwrap_or_default(),
|
||
wa.allowed_numbers.clone(),
|
||
)));
|
||
} else {
|
||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||
}
|
||
}
|
||
"web" => {
|
||
// Web mode: requires session_path
|
||
#[cfg(feature = "whatsapp-web")]
|
||
if wa.is_web_config() {
|
||
channels.push(Arc::new(WhatsAppWebChannel::new(
|
||
wa.session_path.clone().unwrap_or_default(),
|
||
wa.pair_phone.clone(),
|
||
wa.pair_code.clone(),
|
||
wa.allowed_numbers.clone(),
|
||
)));
|
||
} else {
|
||
tracing::warn!("WhatsApp Web configured but session_path not set");
|
||
}
|
||
#[cfg(not(feature = "whatsapp-web"))]
|
||
{
|
||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||
}
|
||
}
|
||
_ => {
|
||
tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set");
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(ref lq) = config.channels_config.linq {
|
||
channels.push(Arc::new(LinqChannel::new(
|
||
lq.api_token.clone(),
|
||
lq.from_phone.clone(),
|
||
lq.allowed_senders.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref email_cfg) = config.channels_config.email {
|
||
channels.push(Arc::new(EmailChannel::new(email_cfg.clone())));
|
||
}
|
||
|
||
if let Some(ref irc) = config.channels_config.irc {
|
||
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||
server: irc.server.clone(),
|
||
port: irc.port,
|
||
nickname: irc.nickname.clone(),
|
||
username: irc.username.clone(),
|
||
channels: irc.channels.clone(),
|
||
allowed_users: irc.allowed_users.clone(),
|
||
server_password: irc.server_password.clone(),
|
||
nickserv_password: irc.nickserv_password.clone(),
|
||
sasl_password: irc.sasl_password.clone(),
|
||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||
})));
|
||
}
|
||
|
||
if let Some(ref lk) = config.channels_config.lark {
|
||
channels.push(Arc::new(LarkChannel::from_config(lk)));
|
||
}
|
||
|
||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||
channels.push(Arc::new(DingTalkChannel::new(
|
||
dt.client_id.clone(),
|
||
dt.client_secret.clone(),
|
||
dt.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref qq) = config.channels_config.qq {
|
||
channels.push(Arc::new(QQChannel::new(
|
||
qq.app_id.clone(),
|
||
qq.app_secret.clone(),
|
||
qq.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if channels.is_empty() {
|
||
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
|
||
return Ok(());
|
||
}
|
||
|
||
println!("🦀 ZeroClaw Channel Server");
|
||
println!(" 🤖 Model: {model}");
|
||
let effective_backend = memory::effective_memory_backend_name(
|
||
&config.memory.backend,
|
||
Some(&config.storage.provider.config),
|
||
);
|
||
println!(
|
||
" 🧠 Memory: {} (auto-save: {})",
|
||
effective_backend,
|
||
if config.memory.auto_save { "on" } else { "off" }
|
||
);
|
||
println!(
|
||
" 📡 Channels: {}",
|
||
channels
|
||
.iter()
|
||
.map(|c| c.name())
|
||
.collect::<Vec<_>>()
|
||
.join(", ")
|
||
);
|
||
println!();
|
||
println!(" Listening for messages... (Ctrl+C to stop)");
|
||
println!();
|
||
|
||
crate::health::mark_component_ok("channels");
|
||
|
||
let initial_backoff_secs = config
|
||
.reliability
|
||
.channel_initial_backoff_secs
|
||
.max(DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS);
|
||
let max_backoff_secs = config
|
||
.reliability
|
||
.channel_max_backoff_secs
|
||
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
|
||
|
||
// Single message bus — all channels send messages here
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||
|
||
// Spawn a listener for each channel
|
||
let mut handles = Vec::new();
|
||
for ch in &channels {
|
||
handles.push(spawn_supervised_listener(
|
||
ch.clone(),
|
||
tx.clone(),
|
||
initial_backoff_secs,
|
||
max_backoff_secs,
|
||
));
|
||
}
|
||
drop(tx); // Drop our copy so rx closes when all channels stop
|
||
|
||
let channels_by_name = Arc::new(
|
||
channels
|
||
.iter()
|
||
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
|
||
.collect::<HashMap<_, _>>(),
|
||
);
|
||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||
|
||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
||
let message_timeout_secs =
|
||
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
||
let interrupt_on_new_message = config
|
||
.channels_config
|
||
.telegram
|
||
.as_ref()
|
||
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name,
|
||
provider: Arc::clone(&provider),
|
||
default_provider: Arc::new(provider_name),
|
||
memory: Arc::clone(&mem),
|
||
tools_registry: Arc::clone(&tools_registry),
|
||
observer,
|
||
system_prompt: Arc::new(system_prompt),
|
||
model: Arc::new(model.clone()),
|
||
temperature,
|
||
auto_save_memory: config.memory.auto_save,
|
||
max_tool_iterations: config.agent.max_tool_iterations,
|
||
min_relevance_score: config.memory.min_relevance_score,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: config.api_key.clone(),
|
||
api_url: config.api_url.clone(),
|
||
reliability: Arc::new(config.reliability.clone()),
|
||
provider_runtime_options,
|
||
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
||
message_timeout_secs,
|
||
interrupt_on_new_message,
|
||
multimodal: config.multimodal.clone(),
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||
|
||
// Wait for all channel tasks
|
||
for h in handles {
|
||
let _ = h.await;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||
use crate::observability::NoopObserver;
|
||
use crate::providers::{ChatMessage, Provider};
|
||
use crate::tools::{Tool, ToolResult};
|
||
use std::collections::HashMap;
|
||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||
use std::sync::Arc;
|
||
use tempfile::TempDir;
|
||
|
||
fn make_workspace() -> TempDir {
|
||
let tmp = TempDir::new().unwrap();
|
||
// Create minimal workspace files
|
||
std::fs::write(tmp.path().join("SOUL.md"), "# Soul\nBe helpful.").unwrap();
|
||
std::fs::write(tmp.path().join("IDENTITY.md"), "# Identity\nName: ZeroClaw").unwrap();
|
||
std::fs::write(tmp.path().join("USER.md"), "# User\nName: Test User").unwrap();
|
||
std::fs::write(
|
||
tmp.path().join("AGENTS.md"),
|
||
"# Agents\nFollow instructions.",
|
||
)
|
||
.unwrap();
|
||
std::fs::write(tmp.path().join("TOOLS.md"), "# Tools\nUse shell carefully.").unwrap();
|
||
std::fs::write(
|
||
tmp.path().join("HEARTBEAT.md"),
|
||
"# Heartbeat\nCheck status.",
|
||
)
|
||
.unwrap();
|
||
std::fs::write(tmp.path().join("MEMORY.md"), "# Memory\nUser likes Rust.").unwrap();
|
||
tmp
|
||
}
|
||
|
||
/// `effective_channel_message_timeout_secs` never returns less than
/// `MIN_CHANNEL_MESSAGE_TIMEOUT_SECS` and passes larger values through unchanged.
#[test]
fn effective_channel_message_timeout_secs_clamps_to_minimum() {
    assert_eq!(
        effective_channel_message_timeout_secs(0),
        MIN_CHANNEL_MESSAGE_TIMEOUT_SECS
    );
    assert_eq!(
        effective_channel_message_timeout_secs(15),
        MIN_CHANNEL_MESSAGE_TIMEOUT_SECS
    );
    // Boundary: the minimum itself must come back unchanged (no off-by-one
    // in the clamp comparison).
    assert_eq!(
        effective_channel_message_timeout_secs(MIN_CHANNEL_MESSAGE_TIMEOUT_SECS),
        MIN_CHANNEL_MESSAGE_TIMEOUT_SECS
    );
    assert_eq!(effective_channel_message_timeout_secs(300), 300);
}
|
||
|
||
/// The overflow detector must flag a provider error that says the input
/// exceeds the context window, and must ignore an unrelated 502 gateway error.
#[test]
fn context_window_overflow_error_detector_matches_known_messages() {
    // (error text, expected detector verdict)
    let cases = [
        (
            "OpenAI Codex stream error: Your input exceeds the context window of this model.",
            true,
        ),
        (
            "OpenAI Codex API error (502 Bad Gateway): error code: 502",
            false,
        ),
    ];
    for (message, expected) in cases {
        let err = anyhow::anyhow!(message);
        assert_eq!(
            is_context_window_overflow_error(&err),
            expected,
            "unexpected verdict for: {message}"
        );
    }
}
|
||
|
||
/// History blobs and legacy `assistant_resp_*` entries are excluded from
/// memory context, while an ordinary message entry is kept.
#[test]
fn memory_context_skip_rules_exclude_history_blobs() {
    let must_skip = [
        ("telegram_123_history", r#"[{"role":"user"}]"#),
        ("assistant_resp_legacy", "fabricated memory"),
    ];
    for (key, content) in must_skip {
        assert!(
            should_skip_memory_context_entry(key, content),
            "expected `{key}` to be skipped"
        );
    }

    let kept = should_skip_memory_context_entry("telegram_123_45", "hi");
    assert!(!kept);
}
|
||
|
||
/// Compacting a long sender history keeps exactly
/// `CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES` turns. Each kept turn is at most
/// `CHANNEL_HISTORY_COMPACT_CONTENT_CHARS` characters, optionally followed by
/// a trailing `...`.
#[test]
fn compact_sender_history_keeps_recent_truncated_messages() {
    let sender = "telegram_u1".to_string();

    // Twenty alternating user/assistant turns, each well over 700 chars.
    let turns: Vec<_> = (0..20)
        .map(|idx| {
            let body = format!("msg-{idx}-{}", "x".repeat(700));
            match idx % 2 {
                0 => ChatMessage::user(body),
                _ => ChatMessage::assistant(body),
            }
        })
        .collect();
    let mut histories = HashMap::new();
    histories.insert(sender.clone(), turns);

    let ctx = ChannelRuntimeContext {
        channels_by_name: Arc::new(HashMap::new()),
        provider: Arc::new(DummyProvider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("system".to_string()),
        model: Arc::new("test-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(histories)),
        provider_cache: Arc::new(Mutex::new(HashMap::new())),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        interrupt_on_new_message: false,
        multimodal: crate::config::MultimodalConfig::default(),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
        message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
    };

    assert!(compact_sender_history(&ctx, &sender));

    let guard = ctx
        .conversation_histories
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    let kept = guard.get(&sender).expect("sender history should remain");
    assert_eq!(kept.len(), CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES);

    for turn in kept {
        let char_count = turn.content.chars().count();
        let within_limit = char_count <= CHANNEL_HISTORY_COMPACT_CONTENT_CHARS;
        let truncated_with_ellipsis = char_count <= CHANNEL_HISTORY_COMPACT_CONTENT_CHARS + 3
            && turn.content.ends_with("...");
        assert!(within_limit || truncated_with_ellipsis);
    }
}
|
||
|
||
struct DummyProvider;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for DummyProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("ok".to_string())
|
||
}
|
||
}
|
||
|
||
#[derive(Default)]
|
||
struct RecordingChannel {
|
||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||
start_typing_calls: AtomicUsize,
|
||
stop_typing_calls: AtomicUsize,
|
||
}
|
||
|
||
#[derive(Default)]
|
||
struct TelegramRecordingChannel {
|
||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for TelegramRecordingChannel {
|
||
fn name(&self) -> &str {
|
||
"telegram"
|
||
}
|
||
|
||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||
self.sent_messages
|
||
.lock()
|
||
.await
|
||
.push(format!("{}:{}", message.recipient, message.content));
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for RecordingChannel {
|
||
fn name(&self) -> &str {
|
||
"test-channel"
|
||
}
|
||
|
||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||
self.sent_messages
|
||
.lock()
|
||
.await
|
||
.push(format!("{}:{}", message.recipient, message.content));
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
self.start_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||
Ok(())
|
||
}
|
||
|
||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
self.stop_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
struct SlowProvider {
|
||
delay: Duration,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for SlowProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
tokio::time::sleep(self.delay).await;
|
||
Ok(format!("echo: {message}"))
|
||
}
|
||
}
|
||
|
||
struct ToolCallingProvider;
|
||
|
||
fn tool_call_payload() -> String {
|
||
r#"<tool_call>
|
||
{"name":"mock_price","arguments":{"symbol":"BTC"}}
|
||
</tool_call>"#
|
||
.to_string()
|
||
}
|
||
|
||
fn tool_call_payload_with_alias_tag() -> String {
|
||
r#"<toolcall>
|
||
{"name":"mock_price","arguments":{"symbol":"BTC"}}
|
||
</toolcall>"#
|
||
.to_string()
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ToolCallingProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let has_tool_results = messages
|
||
.iter()
|
||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||
if has_tool_results {
|
||
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
|
||
} else {
|
||
Ok(tool_call_payload())
|
||
}
|
||
}
|
||
}
|
||
|
||
struct ToolCallingAliasProvider;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ToolCallingAliasProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload_with_alias_tag())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let has_tool_results = messages
|
||
.iter()
|
||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||
if has_tool_results {
|
||
Ok("BTC alias-tag flow resolved to final text output.".to_string())
|
||
} else {
|
||
Ok(tool_call_payload_with_alias_tag())
|
||
}
|
||
}
|
||
}
|
||
|
||
struct IterativeToolProvider {
|
||
required_tool_iterations: usize,
|
||
}
|
||
|
||
impl IterativeToolProvider {
|
||
fn completed_tool_iterations(messages: &[ChatMessage]) -> usize {
|
||
messages
|
||
.iter()
|
||
.filter(|msg| msg.role == "user" && msg.content.contains("[Tool results]"))
|
||
.count()
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for IterativeToolProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let completed_iterations = Self::completed_tool_iterations(messages);
|
||
if completed_iterations >= self.required_tool_iterations {
|
||
Ok(format!(
|
||
"Completed after {completed_iterations} tool iterations."
|
||
))
|
||
} else {
|
||
Ok(tool_call_payload())
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Default)]
|
||
struct HistoryCaptureProvider {
|
||
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for HistoryCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let snapshot = messages
|
||
.iter()
|
||
.map(|m| (m.role.clone(), m.content.clone()))
|
||
.collect::<Vec<_>>();
|
||
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||
calls.push(snapshot);
|
||
Ok(format!("response-{}", calls.len()))
|
||
}
|
||
}
|
||
|
||
struct DelayedHistoryCaptureProvider {
|
||
delay: Duration,
|
||
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for DelayedHistoryCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let snapshot = messages
|
||
.iter()
|
||
.map(|m| (m.role.clone(), m.content.clone()))
|
||
.collect::<Vec<_>>();
|
||
let call_index = {
|
||
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||
calls.push(snapshot);
|
||
calls.len()
|
||
};
|
||
tokio::time::sleep(self.delay).await;
|
||
Ok(format!("response-{call_index}"))
|
||
}
|
||
}
|
||
|
||
struct MockPriceTool;
|
||
|
||
#[derive(Default)]
|
||
struct ModelCaptureProvider {
|
||
call_count: AtomicUsize,
|
||
models: std::sync::Mutex<Vec<String>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ModelCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
_messages: &[ChatMessage],
|
||
model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||
self.models
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.push(model.to_string());
|
||
Ok("ok".to_string())
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Tool for MockPriceTool {
|
||
fn name(&self) -> &str {
|
||
"mock_price"
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"Return a mocked BTC price"
|
||
}
|
||
|
||
fn parameters_schema(&self) -> serde_json::Value {
|
||
serde_json::json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"symbol": { "type": "string" }
|
||
},
|
||
"required": ["symbol"]
|
||
})
|
||
}
|
||
|
||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||
let symbol = args.get("symbol").and_then(serde_json::Value::as_str);
|
||
if symbol != Some("BTC") {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("unexpected symbol".to_string()),
|
||
});
|
||
}
|
||
|
||
Ok(ToolResult {
|
||
success: true,
|
||
output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(),
|
||
error: None,
|
||
})
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(ToolCallingProvider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-42".to_string(),
|
||
content: "What is the BTC price now?".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-42:"));
|
||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||
assert!(!sent_messages[0].contains("mock_price"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_executes_tool_calls_with_alias_tags() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(ToolCallingAliasProvider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-2".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "chat-84".to_string(),
|
||
content: "What is the BTC price now?".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-84:"));
|
||
assert!(sent_messages[0].contains("alias-tag flow resolved"));
|
||
assert!(!sent_messages[0].contains("<toolcall>"));
|
||
assert!(!sent_messages[0].contains("mock_price"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_handles_models_command_without_llm_call() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||
let fallback_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let fallback_provider: Arc<dyn Provider> = fallback_provider_impl.clone();
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||
provider_cache_seed.insert("openrouter".to_string(), fallback_provider);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::clone(&default_provider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("default-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "msg-cmd-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "/models openrouter".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent.len(), 1);
|
||
assert!(sent[0].contains("Provider switched to `openrouter`"));
|
||
|
||
let route_key = "telegram_alice";
|
||
let route = runtime_ctx
|
||
.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(route_key)
|
||
.cloned()
|
||
.expect("route should be stored for sender");
|
||
assert_eq!(route.provider, "openrouter");
|
||
assert_eq!(route.model, "default-model");
|
||
|
||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
assert_eq!(fallback_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_uses_route_override_provider_and_model() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||
let routed_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let routed_provider: Arc<dyn Provider> = routed_provider_impl.clone();
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||
provider_cache_seed.insert("openrouter".to_string(), routed_provider);
|
||
|
||
let route_key = "telegram_alice".to_string();
|
||
let mut route_overrides = HashMap::new();
|
||
route_overrides.insert(
|
||
route_key,
|
||
ChannelRouteSelection {
|
||
provider: "openrouter".to_string(),
|
||
model: "route-model".to_string(),
|
||
},
|
||
);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::clone(&default_provider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("default-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(route_overrides)),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-routed-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "hello routed provider".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
assert_eq!(routed_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||
assert_eq!(
|
||
routed_provider_impl
|
||
.models
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.as_slice(),
|
||
&["route-model".to_string()]
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_respects_configured_max_tool_iterations_above_default() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(IterativeToolProvider {
|
||
required_tool_iterations: 11,
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 12,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-iter-success".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-iter-success".to_string(),
|
||
content: "Loop until done".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-iter-success:"));
|
||
assert!(sent_messages[0].contains("Completed after 11 tool iterations."));
|
||
assert!(!sent_messages[0].contains("⚠️ Error:"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_reports_configured_max_tool_iterations_limit() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(IterativeToolProvider {
|
||
required_tool_iterations: 20,
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 3,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-iter-fail".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "chat-iter-fail".to_string(),
|
||
content: "Loop forever".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-iter-fail:"));
|
||
assert!(sent_messages[0].contains("⚠️ Error: Agent exceeded maximum tool iterations (3)"));
|
||
}
|
||
|
||
struct NoopMemory;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Memory for NoopMemory {
|
||
fn name(&self) -> &str {
|
||
"noop"
|
||
}
|
||
|
||
async fn store(
|
||
&self,
|
||
_key: &str,
|
||
_content: &str,
|
||
_category: crate::memory::MemoryCategory,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn recall(
|
||
&self,
|
||
_query: &str,
|
||
_limit: usize,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||
Ok(Vec::new())
|
||
}
|
||
|
||
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
|
||
Ok(None)
|
||
}
|
||
|
||
async fn list(
|
||
&self,
|
||
_category: Option<&crate::memory::MemoryCategory>,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||
Ok(Vec::new())
|
||
}
|
||
|
||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||
Ok(false)
|
||
}
|
||
|
||
async fn count(&self) -> anyhow::Result<usize> {
|
||
Ok(0)
|
||
}
|
||
|
||
async fn health_check(&self) -> bool {
|
||
true
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_processes_messages_in_parallel() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(250),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||
tx.send(traits::ChannelMessage {
|
||
id: "1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "alice".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tx.send(traits::ChannelMessage {
|
||
id: "2".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "bob".to_string(),
|
||
content: "world".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
drop(tx);
|
||
|
||
let started = Instant::now();
|
||
run_message_dispatch_loop(rx, runtime_ctx, 2).await;
|
||
let elapsed = started.elapsed();
|
||
|
||
assert!(
|
||
elapsed < Duration::from_millis(430),
|
||
"expected parallel dispatch (<430ms), got {:?}",
|
||
elapsed
|
||
);
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 2);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(DelayedHistoryCaptureProvider {
|
||
delay: Duration::from_millis(250),
|
||
calls: std::sync::Mutex::new(Vec::new()),
|
||
});
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: true,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||
let send_task = tokio::spawn(async move {
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "forwarded content".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tokio::time::sleep(Duration::from_millis(40)).await;
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-2".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "summarize this".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||
send_task.await.unwrap();
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-1:"));
|
||
assert!(sent_messages[0].contains("response-2"));
|
||
drop(sent_messages);
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 2);
|
||
let second_call = &calls[1];
|
||
assert!(second_call
|
||
.iter()
|
||
.any(|(role, content)| { role == "user" && content.contains("forwarded content") }));
|
||
assert!(second_call
|
||
.iter()
|
||
.any(|(role, content)| { role == "user" && content.contains("summarize this") }));
|
||
assert!(
|
||
!second_call.iter().any(|(role, _)| role == "assistant"),
|
||
"cancelled turn should not persist an assistant response"
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(180),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: true,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||
let send_task = tokio::spawn(async move {
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-a".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "first chat".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-b".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-2".to_string(),
|
||
content: "second chat".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||
send_task.await.unwrap();
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 2);
|
||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:")));
|
||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:")));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_cancels_scoped_typing_task() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(20),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "typing-msg".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-typing".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let starts = channel_impl.start_typing_calls.load(Ordering::SeqCst);
|
||
let stops = channel_impl.stop_typing_calls.load(Ordering::SeqCst);
|
||
assert_eq!(starts, 1, "start_typing should be called once");
|
||
assert_eq!(stops, 1, "stop_typing should be called once");
|
||
}
|
||
|
||
/// The channel system prompt carries every top-level section header.
#[test]
fn prompt_contains_all_sections() {
    let ws = make_workspace();
    let tools = [("shell", "Run commands"), ("file_read", "Read files")];
    let prompt = build_system_prompt(ws.path(), "test-model", &tools, &[], None, None);

    // Section headers paired with the message reported if one is missing.
    let expected_sections = [
        ("## Tools", "missing Tools section"),
        ("## Safety", "missing Safety section"),
        ("## Workspace", "missing Workspace section"),
        ("## Project Context", "missing Project Context"),
        ("## Current Date & Time", "missing Date/Time"),
        ("## Runtime", "missing Runtime section"),
    ];
    for (header, failure) in expected_sections {
        assert!(prompt.contains(header), "{failure}");
    }
}
|
||
|
||
/// Tool names (bolded) and their descriptions are rendered into the prompt.
#[test]
fn prompt_injects_tools() {
    let ws = make_workspace();
    let prompt = build_system_prompt(
        ws.path(),
        "gpt-4o",
        &[("shell", "Run commands"), ("memory_recall", "Search memory")],
        &[],
        None,
        None,
    );

    for fragment in ["**shell**", "Run commands", "**memory_recall**"] {
        assert!(prompt.contains(fragment));
    }
}
|
||
|
||
/// `build_system_prompt` must not emit the tool protocol itself. After appending
/// `build_tool_instructions`, the protocol header appears exactly once.
#[test]
fn prompt_includes_single_tool_protocol_block_after_append() {
    const PROTOCOL_HEADER: &str = "## Tool Use Protocol";

    let ws = make_workspace();
    let mut prompt = build_system_prompt(
        ws.path(),
        "gpt-4o",
        &[("shell", "Run commands")],
        &[],
        None,
        None,
    );
    assert!(
        !prompt.contains(PROTOCOL_HEADER),
        "build_system_prompt should not emit protocol block directly"
    );

    let instructions = build_tool_instructions(&[]);
    prompt.push_str(&instructions);

    assert_eq!(
        prompt.matches(PROTOCOL_HEADER).count(),
        1,
        "protocol block should appear exactly once in the final prompt"
    );
}
|
||
|
||
/// The safety guidance lines are present even with no tools or skills.
#[test]
fn prompt_injects_safety() {
    let ws = make_workspace();
    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    let rules = [
        "Do not exfiltrate private data",
        "Do not run destructive commands",
        "Prefer `trash` over `rm`",
    ];
    for rule in rules {
        assert!(prompt.contains(rule));
    }
}
|
||
|
||
/// Workspace files (SOUL, IDENTITY, USER, AGENTS, TOOLS, MEMORY) appear with their
/// headers and content; HEARTBEAT.md is deliberately left out.
#[test]
fn prompt_injects_workspace_files() {
    let ws = make_workspace();
    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    let expected = [
        ("### SOUL.md", "missing SOUL.md header"),
        ("Be helpful", "missing SOUL content"),
        ("### IDENTITY.md", "missing IDENTITY.md"),
        ("Name: ZeroClaw", "missing IDENTITY content"),
        ("### USER.md", "missing USER.md"),
        ("### AGENTS.md", "missing AGENTS.md"),
        ("### TOOLS.md", "missing TOOLS.md"),
        ("### MEMORY.md", "missing MEMORY.md"),
        ("User likes Rust", "missing MEMORY content"),
    ];
    for (needle, failure) in expected {
        assert!(prompt.contains(needle), "{failure}");
    }

    // HEARTBEAT.md is intentionally excluded from channel prompts — it's only
    // relevant to the heartbeat worker and causes LLMs to emit spurious
    // "HEARTBEAT_OK" acknowledgments in channel conversations.
    assert!(
        !prompt.contains("### HEARTBEAT.md"),
        "HEARTBEAT.md should not be in channel prompt"
    );
}
|
||
|
||
/// In an empty workspace, each expected file is reported with a
/// `[File not found: …]` marker.
#[test]
fn prompt_missing_file_markers() {
    let tmp = TempDir::new().unwrap();
    let prompt = build_system_prompt(tmp.path(), "model", &[], &[], None, None);

    for name in ["SOUL.md", "AGENTS.md", "IDENTITY.md"] {
        assert!(prompt.contains(&format!("[File not found: {name}]")));
    }
}
|
||
|
||
/// BOOTSTRAP.md is injected only when the file exists in the workspace.
#[test]
fn prompt_bootstrap_only_if_exists() {
    let ws = make_workspace();
    let render = || build_system_prompt(ws.path(), "model", &[], &[], None, None);

    // Before the file is created, no section is emitted.
    assert!(
        !render().contains("### BOOTSTRAP.md"),
        "BOOTSTRAP.md should not appear when missing"
    );

    // After creating it, both the header and the body show up.
    std::fs::write(ws.path().join("BOOTSTRAP.md"), "# Bootstrap\nFirst run.").unwrap();
    let with_bootstrap = render();
    assert!(
        with_bootstrap.contains("### BOOTSTRAP.md"),
        "BOOTSTRAP.md should appear when present"
    );
    assert!(with_bootstrap.contains("First run"));
}
|
||
|
||
/// Today's daily note under `memory/` must not be copied into the system prompt.
#[test]
fn prompt_no_daily_memory_injection() {
    let ws = make_workspace();
    let memory_dir = ws.path().join("memory");
    std::fs::create_dir_all(&memory_dir).unwrap();

    let today = chrono::Local::now().format("%Y-%m-%d").to_string();
    let daily_note = memory_dir.join(format!("{today}.md"));
    std::fs::write(&daily_note, "# Daily\nSome note.").unwrap();

    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    // Daily notes should NOT be in the system prompt (on-demand via tools)
    assert!(
        !prompt.contains("Daily Notes"),
        "daily notes should not be auto-injected"
    );
    assert!(
        !prompt.contains("Some note"),
        "daily content should not be in prompt"
    );
}
|
||
|
||
/// The runtime section reports the model name, host OS and a host line.
#[test]
fn prompt_runtime_metadata() {
    let ws = make_workspace();
    let prompt = build_system_prompt(ws.path(), "claude-sonnet-4", &[], &[], None, None);

    let os_line = format!("OS: {}", std::env::consts::OS);
    for fragment in ["Model: claude-sonnet-4", os_line.as_str(), "Host:"] {
        assert!(prompt.contains(fragment));
    }
}
|
||
|
||
/// A skill is rendered inline in `<available_skills>` XML, including its
/// instructions and tool definitions (not deferred as "loaded on demand").
#[test]
fn prompt_skills_include_instructions_and_tools() {
    let ws = make_workspace();

    let lint_tool = crate::skills::SkillTool {
        name: "lint".into(),
        description: "Run static checks".into(),
        kind: "shell".into(),
        command: "cargo clippy".into(),
        args: HashMap::new(),
    };
    let skills = vec![crate::skills::Skill {
        name: "code-review".into(),
        description: "Review code for bugs".into(),
        version: "1.0.0".into(),
        author: None,
        tags: vec![],
        tools: vec![lint_tool],
        prompts: vec!["Always run cargo test before final response.".into()],
        location: None,
    }];

    let prompt = build_system_prompt(ws.path(), "model", &[], &skills, None, None);

    assert!(prompt.contains("<available_skills>"), "missing skills XML");
    let fragments = [
        "<name>code-review</name>",
        "<description>Review code for bugs</description>",
        "SKILL.md</location>",
        "<instructions>",
        "<instruction>Always run cargo test before final response.</instruction>",
        "<tools>",
        "<name>lint</name>",
        "<kind>shell</kind>",
    ];
    for fragment in fragments {
        assert!(prompt.contains(fragment));
    }
    assert!(!prompt.contains("loaded on demand"));
}
|
||
|
||
/// Skill metadata containing XML-reserved characters (`<`, `>`, `&`, `"`, `'`)
/// must be entity-escaped inside the `<available_skills>` XML.
///
/// The expectations are the escaped forms (`&lt;`, `&gt;`, `&amp;`, `&quot;`,
/// `&apos;`). The raw values must not appear inside element bodies, or they
/// could break the XML structure or inject tags.
#[test]
fn prompt_skills_escape_reserved_xml_chars() {
    let ws = make_workspace();
    let skills = vec![crate::skills::Skill {
        name: "code<review>&".into(),
        description: "Review \"unsafe\" and 'risky' bits".into(),
        version: "1.0.0".into(),
        author: None,
        tags: vec![],
        tools: vec![crate::skills::SkillTool {
            name: "run\"linter\"".into(),
            description: "Run <lint> & report".into(),
            kind: "shell&exec".into(),
            command: "cargo clippy".into(),
            args: HashMap::new(),
        }],
        prompts: vec!["Use <tool_call> and & keep output \"safe\"".into()],
        location: None,
    }];

    let prompt = build_system_prompt(ws.path(), "model", &[], &skills, None, None);

    // Escaped forms must be present.
    assert!(prompt.contains("<name>code&lt;review&gt;&amp;</name>"));
    assert!(prompt.contains(
        "<description>Review &quot;unsafe&quot; and &apos;risky&apos; bits</description>"
    ));
    assert!(prompt.contains("<name>run&quot;linter&quot;</name>"));
    assert!(prompt.contains("<description>Run &lt;lint&gt; &amp; report</description>"));
    assert!(prompt.contains("<kind>shell&amp;exec</kind>"));
    assert!(prompt.contains(
        "<instruction>Use &lt;tool_call&gt; and &amp; keep output &quot;safe&quot;</instruction>"
    ));

    // Raw, unescaped values must never leak into the XML elements.
    assert!(!prompt.contains("<name>code<review>&</name>"));
    assert!(!prompt.contains("<description>Run <lint> & report</description>"));
    assert!(!prompt.contains("<kind>shell&exec</kind>"));
}
|
||
|
||
/// A workspace file longer than `BOOTSTRAP_MAX_CHARS` is cut down and flagged as
/// truncated rather than embedded in full.
#[test]
fn prompt_truncation() {
    let ws = make_workspace();
    let oversized = "x".repeat(BOOTSTRAP_MAX_CHARS + 1000);
    std::fs::write(ws.path().join("AGENTS.md"), &oversized).unwrap();

    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    assert!(
        prompt.contains("truncated at"),
        "large files should be truncated"
    );
    assert!(
        !prompt.contains(&oversized),
        "full content should not appear"
    );
}
|
||
|
||
/// An empty workspace file produces no section header at all.
#[test]
fn prompt_empty_files_skipped() {
    let ws = make_workspace();
    let tools_md = ws.path().join("TOOLS.md");
    std::fs::write(tools_md, "").unwrap();

    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    assert!(
        !prompt.contains("### TOOLS.md"),
        "empty files should be skipped"
    );
}
|
||
|
||
/// Truncating log text that contains multibyte characters (emoji, accented
/// letters) must not panic, and it must produce a genuinely shortened string
/// that keeps the leading multibyte characters intact.
#[test]
fn channel_log_truncation_is_utf8_safe_for_multibyte_text() {
    let msg = "Hello from ZeroClaw 🌍. Current status is healthy, and café-style UTF-8 text stays safe in logs.";

    // Reproduces the production crash path where channel logs truncate at 80 chars.
    let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80));
    assert!(
        result.is_ok(),
        "truncate_with_ellipsis should never panic on UTF-8"
    );

    let truncated = result.unwrap();
    assert!(!truncated.is_empty());
    // `s.is_char_boundary(s.len())` holds for every `String`, so it cannot detect
    // anything. Instead, check that the input (95 chars, longer than the 80-char
    // limit) was actually shortened...
    assert!(
        truncated.chars().count() < msg.chars().count(),
        "message longer than the limit should be shortened, got {truncated:?}"
    );
    // ...and that the prefix containing the 4-byte emoji survived unchanged.
    assert!(
        truncated.starts_with("Hello from ZeroClaw 🌍"),
        "multibyte prefix should be preserved, got {truncated:?}"
    );
}
|
||
|
||
/// The prompt explains the messaging-bot context and includes the instruction
/// never to echo credentials.
#[test]
fn prompt_contains_channel_capabilities() {
    let ws = make_workspace();
    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    let expectations = [
        ("## Channel Capabilities", "missing Channel Capabilities section"),
        ("running as a messaging bot", "missing channel context"),
        (
            "NEVER repeat, describe, or echo credentials",
            "missing security instruction",
        ),
    ];
    for (needle, failure) in expectations {
        assert!(prompt.contains(needle), "{failure}");
    }
}
|
||
|
||
/// The workspace directory is shown to the model as its working directory.
#[test]
fn prompt_workspace_path() {
    let ws = make_workspace();
    let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);

    let expected_line = format!("Working directory: `{}`", ws.path().display());
    assert!(prompt.contains(&expected_line));
}
|
||
|
||
/// The autosave key is `<channel>_<sender>_<message id>`.
#[test]
fn conversation_memory_key_uses_message_id() {
    let incoming = traits::ChannelMessage {
        id: "msg_abc123".to_string(),
        sender: "U123".to_string(),
        reply_target: "C456".to_string(),
        content: "hello".to_string(),
        channel: "slack".to_string(),
        timestamp: 1,
        thread_ts: None,
    };

    let key = conversation_memory_key(&incoming);
    assert_eq!(key, "slack_U123_msg_abc123");
}
|
||
|
||
/// Two messages from the same sender in the same conversation must get distinct
/// autosave keys.
#[test]
fn conversation_memory_key_is_unique_per_message() {
    let slack_message = |id: &str, content: &str, timestamp| traits::ChannelMessage {
        id: id.to_string(),
        sender: "U123".to_string(),
        reply_target: "C456".to_string(),
        content: content.to_string(),
        channel: "slack".to_string(),
        timestamp,
        thread_ts: None,
    };

    let first = slack_message("msg_1", "first", 1);
    let second = slack_message("msg_2", "second", 2);

    assert_ne!(
        conversation_memory_key(&first),
        conversation_memory_key(&second)
    );
}
|
||
|
||
#[tokio::test]
|
||
async fn autosave_keys_preserve_multiple_conversation_facts() {
|
||
let tmp = TempDir::new().unwrap();
|
||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||
|
||
let msg1 = traits::ChannelMessage {
|
||
id: "msg_1".into(),
|
||
sender: "U123".into(),
|
||
reply_target: "C456".into(),
|
||
content: "I'm Paul".into(),
|
||
channel: "slack".into(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
};
|
||
let msg2 = traits::ChannelMessage {
|
||
id: "msg_2".into(),
|
||
sender: "U123".into(),
|
||
reply_target: "C456".into(),
|
||
content: "I'm 45".into(),
|
||
channel: "slack".into(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
};
|
||
|
||
mem.store(
|
||
&conversation_memory_key(&msg1),
|
||
&msg1.content,
|
||
MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
mem.store(
|
||
&conversation_memory_key(&msg2),
|
||
&msg2.content,
|
||
MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(mem.count().await.unwrap(), 2);
|
||
|
||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn build_memory_context_includes_recalled_entries() {
|
||
let tmp = TempDir::new().unwrap();
|
||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||
.await
|
||
.unwrap();
|
||
|
||
let context = build_memory_context(&mem, "age", 0.0).await;
|
||
assert!(context.contains("[Memory context]"));
|
||
assert!(context.contains("Age is 45"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_restores_per_sender_history_on_follow_ups() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "msg-a".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-b".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "follow up".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 2);
|
||
assert_eq!(calls[0].len(), 2);
|
||
assert_eq!(calls[0][0].0, "system");
|
||
assert_eq!(calls[0][1].0, "user");
|
||
assert_eq!(calls[1].len(), 4);
|
||
assert_eq!(calls[1][0].0, "system");
|
||
assert_eq!(calls[1][1].0, "user");
|
||
assert_eq!(calls[1][2].0, "assistant");
|
||
assert_eq!(calls[1][3].0, "user");
|
||
assert!(calls[1][1].1.contains("hello"));
|
||
assert!(calls[1][2].1.contains("response-1"));
|
||
assert!(calls[1][3].1.contains("follow up"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||
let mut histories = HashMap::new();
|
||
histories.insert(
|
||
"telegram_alice".to_string(),
|
||
vec![
|
||
ChatMessage::assistant("stale assistant"),
|
||
ChatMessage::user("earlier user question"),
|
||
ChatMessage::assistant("earlier assistant reply"),
|
||
],
|
||
);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "tg-msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-telegram".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 1);
|
||
assert_eq!(calls[0].len(), 4);
|
||
|
||
let roles = calls[0]
|
||
.iter()
|
||
.map(|(role, _)| role.as_str())
|
||
.collect::<Vec<_>>();
|
||
assert_eq!(roles, vec!["system", "user", "assistant", "user"]);
|
||
assert!(
|
||
calls[0][0]
|
||
.1
|
||
.contains("When responding on Telegram, include media markers"),
|
||
"telegram delivery instruction should live in the system prompt"
|
||
);
|
||
assert!(!calls[0].iter().skip(1).any(|(role, _)| role == "system"));
|
||
}
|
||
|
||
// ── AIEOS Identity Tests (Issue #168) ─────────────────────────
|
||
|
||
/// An `aieos`-format identity loaded from a JSON file in the workspace is
/// rendered into Identity / Personality / Communication Style sections. The
/// OpenClaw bootstrap file headers must not appear.
#[test]
fn aieos_identity_from_file() {
    use crate::config::IdentityConfig;
    use tempfile::TempDir;

    let workspace = TempDir::new().unwrap();

    // Write AIEOS identity file
    let aieos_json = r#"{
"identity": {
"names": {"first": "Nova", "nickname": "Nov"},
"bio": "A helpful AI assistant.",
"origin": "Silicon Valley"
},
"psychology": {
"mbti": "INTJ",
"moral_compass": ["Be helpful", "Do no harm"]
},
"linguistics": {
"style": "concise",
"formality": "casual"
}
}"#;
    std::fs::write(workspace.path().join("aieos_identity.json"), aieos_json).unwrap();

    // The path is relative; presumably it is resolved against the workspace
    // directory handed to `build_system_prompt`.
    let config = IdentityConfig {
        format: "aieos".into(),
        aieos_path: Some("aieos_identity.json".into()),
        aieos_inline: None,
    };

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], Some(&config), None);

    let expected_fragments = [
        // Identity section
        "## Identity",
        "**Name:** Nova",
        "**Nickname:** Nov",
        "**Bio:** A helpful AI assistant.",
        "**Origin:** Silicon Valley",
        // Personality section
        "## Personality",
        "**MBTI:** INTJ",
        "**Moral Compass:**",
        "- Be helpful",
        // Communication style section
        "## Communication Style",
        "**Style:** concise",
        "**Formality Level:** casual",
    ];
    for fragment in expected_fragments {
        assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
    }

    // Should NOT contain OpenClaw bootstrap file headers
    for fragment in ["### SOUL.md", "### IDENTITY.md", "[File not found"] {
        assert!(
            !prompt.contains(fragment),
            "prompt unexpectedly contains {:?}",
            fragment
        );
    }
}
|
||
|
||
/// An inline AIEOS JSON document (with no file on disk) is parsed and rendered
/// into the prompt's Identity section.
#[test]
fn aieos_identity_from_inline() {
    use crate::config::IdentityConfig;

    let inline_identity = r#"{"identity":{"names":{"first":"Claw"}}}"#;
    let config = IdentityConfig {
        format: "aieos".into(),
        aieos_path: None,
        aieos_inline: Some(inline_identity.into()),
    };

    let workspace = std::env::temp_dir();
    let prompt = build_system_prompt(
        workspace.as_path(),
        "model",
        &[],
        &[],
        Some(&config),
        None,
    );

    assert!(prompt.contains("**Name:** Claw"));
    assert!(prompt.contains("## Identity"));
}
|
||
|
||
/// When the configured AIEOS identity cannot be loaded, prompt construction
/// falls back to the OpenClaw bootstrap files instead of failing. Despite the
/// test name, the case exercised here is a missing file.
#[test]
fn aieos_fallback_to_openclaw_on_parse_error() {
    use crate::config::IdentityConfig;

    let missing_file_config = IdentityConfig {
        format: "aieos".into(),
        aieos_path: Some("nonexistent.json".into()),
        aieos_inline: None,
    };
    let workspace = make_workspace();

    let prompt = build_system_prompt(
        workspace.path(),
        "model",
        &[],
        &[],
        Some(&missing_file_config),
        None,
    );

    // (Error is logged to stderr with filename, not included in prompt)
    assert!(prompt.contains("### SOUL.md"));
}
|
||
|
||
/// `format = "aieos"` with neither a path nor inline JSON means no AIEOS
/// source is configured, so the OpenClaw bootstrap files are used.
#[test]
fn aieos_empty_uses_openclaw() {
    use crate::config::IdentityConfig;

    let unconfigured = IdentityConfig {
        format: "aieos".into(),
        aieos_path: None,
        aieos_inline: None,
    };
    let workspace = make_workspace();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], Some(&unconfigured), None);

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
    }
}
|
||
|
||
/// With `format = "openclaw"`, the bootstrap files are used and no AIEOS
/// Identity section is rendered, even though `aieos_path` is set.
#[test]
fn openclaw_format_uses_bootstrap_files() {
    use crate::config::IdentityConfig;

    let openclaw = IdentityConfig {
        format: "openclaw".into(),
        aieos_path: Some("identity.json".into()),
        aieos_inline: None,
    };
    let workspace = make_workspace();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], Some(&openclaw), None);

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
    }
    assert!(!prompt.contains("## Identity"));
}
|
||
|
||
/// With no identity configuration at all, the OpenClaw bootstrap files are used.
#[test]
fn none_identity_config_uses_openclaw() {
    let workspace = make_workspace();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment), "prompt is missing {:?}", fragment);
    }
}
|
||
|
||
/// A probe that completed and returned `true` is classified as `Healthy`.
#[test]
fn classify_health_ok_true() {
    assert_eq!(
        classify_health_result(&Ok(true)),
        ChannelHealthState::Healthy
    );
}
|
||
|
||
/// A probe that completed but returned `false` is classified as `Unhealthy`.
#[test]
fn classify_health_ok_false() {
    assert_eq!(
        classify_health_result(&Ok(false)),
        ChannelHealthState::Unhealthy
    );
}
|
||
|
||
#[tokio::test]
|
||
async fn classify_health_timeout() {
|
||
let result = tokio::time::timeout(Duration::from_millis(1), async {
|
||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||
true
|
||
})
|
||
.await;
|
||
let state = classify_health_result(&result);
|
||
assert_eq!(state, ChannelHealthState::Timeout);
|
||
}
|
||
|
||
/// Test double whose `listen` always fails, counting every attempt.
struct AlwaysFailChannel {
    /// Value returned from `Channel::name`.
    name: &'static str,
    /// Shared counter incremented once per `listen` invocation.
    calls: Arc<AtomicUsize>,
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for AlwaysFailChannel {
|
||
fn name(&self) -> &str {
|
||
self.name
|
||
}
|
||
|
||
async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||
anyhow::bail!("listen boom")
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn supervised_listener_marks_error_and_restarts_on_failures() {
|
||
let calls = Arc::new(AtomicUsize::new(0));
|
||
let channel: Arc<dyn Channel> = Arc::new(AlwaysFailChannel {
|
||
name: "test-supervised-fail",
|
||
calls: Arc::clone(&calls),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
|
||
let handle = spawn_supervised_listener(channel, tx, 1, 1);
|
||
|
||
tokio::time::sleep(Duration::from_millis(80)).await;
|
||
drop(rx);
|
||
handle.abort();
|
||
let _ = handle.await;
|
||
|
||
let snapshot = crate::health::snapshot_json();
|
||
let component = &snapshot["components"]["channel:test-supervised-fail"];
|
||
assert_eq!(component["status"], "error");
|
||
assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
|
||
assert!(component["last_error"]
|
||
.as_str()
|
||
.unwrap_or("")
|
||
.contains("listen boom"));
|
||
assert!(calls.load(Ordering::SeqCst) >= 1);
|
||
}
|
||
}
|