Previously, the memory-enriched message (with [Memory context] block prepended) was saved to per-sender conversation history. On subsequent turns the LLM saw stale memory fragments with raw keys baked into prior "user" messages, creating compounding noise. Save the original msg.content instead. Memory context is still injected for the current LLM call but no longer persists across turns. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
4309 lines
153 KiB
Rust
4309 lines
153 KiB
Rust
//! Channel subsystem for messaging platform integrations.
|
||
//!
|
||
//! This module provides the multi-channel messaging infrastructure that connects
|
||
//! ZeroClaw to external platforms. Each channel implements the [`Channel`] trait
|
||
//! defined in [`traits`], which provides a uniform interface for sending messages,
|
||
//! listening for incoming messages, health checking, and typing indicators.
|
||
//!
|
||
//! Channels are instantiated by [`start_channels`] based on the runtime configuration.
|
||
//! The subsystem manages per-sender conversation history, concurrent message processing
|
||
//! with configurable parallelism, and exponential-backoff reconnection for resilience.
|
||
//!
|
||
//! # Extension
|
||
//!
|
||
//! To add a new channel, implement [`Channel`] in a new submodule and wire it into
|
||
//! [`start_channels`]. See `AGENTS.md` §7.2 for the full change playbook.
|
||
|
||
pub mod cli;
|
||
pub mod dingtalk;
|
||
pub mod discord;
|
||
pub mod email_channel;
|
||
pub mod imessage;
|
||
pub mod irc;
|
||
pub mod lark;
|
||
pub mod linq;
|
||
#[cfg(feature = "channel-matrix")]
|
||
pub mod matrix;
|
||
pub mod mattermost;
|
||
pub mod qq;
|
||
pub mod signal;
|
||
pub mod slack;
|
||
pub mod telegram;
|
||
pub mod traits;
|
||
pub mod whatsapp;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub mod whatsapp_storage;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub mod whatsapp_web;
|
||
|
||
pub use cli::CliChannel;
|
||
pub use dingtalk::DingTalkChannel;
|
||
pub use discord::DiscordChannel;
|
||
pub use email_channel::EmailChannel;
|
||
pub use imessage::IMessageChannel;
|
||
pub use irc::IrcChannel;
|
||
pub use lark::LarkChannel;
|
||
pub use linq::LinqChannel;
|
||
#[cfg(feature = "channel-matrix")]
|
||
pub use matrix::MatrixChannel;
|
||
pub use mattermost::MattermostChannel;
|
||
pub use qq::QQChannel;
|
||
pub use signal::SignalChannel;
|
||
pub use slack::SlackChannel;
|
||
pub use telegram::TelegramChannel;
|
||
pub use traits::{Channel, SendMessage};
|
||
pub use whatsapp::WhatsAppChannel;
|
||
#[cfg(feature = "whatsapp-web")]
|
||
pub use whatsapp_web::WhatsAppWebChannel;
|
||
|
||
use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop};
|
||
use crate::config::Config;
|
||
use crate::identity;
|
||
use crate::memory::{self, Memory};
|
||
use crate::observability::{self, Observer};
|
||
use crate::providers::{self, ChatMessage, Provider};
|
||
use crate::runtime;
|
||
use crate::security::SecurityPolicy;
|
||
use crate::tools::{self, Tool};
|
||
use crate::util::truncate_with_ellipsis;
|
||
use anyhow::{Context, Result};
|
||
use serde::Deserialize;
|
||
use std::collections::HashMap;
|
||
use std::fmt::Write;
|
||
use std::path::{Path, PathBuf};
|
||
use std::process::Command;
|
||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||
use std::sync::{Arc, Mutex};
|
||
use std::time::{Duration, Instant};
|
||
use tokio_util::sync::CancellationToken;
|
||
|
||
/// Per-sender conversation history for channel messages.
|
||
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
|
||
/// Maximum history messages to keep per sender.
|
||
const MAX_CHANNEL_HISTORY: usize = 50;
|
||
/// Minimum user-message length (in chars) for auto-save to memory.
|
||
/// Messages shorter than this (e.g. "ok", "thanks") are not stored,
|
||
/// reducing noise in memory recall.
|
||
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
|
||
|
||
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
|
||
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||
|
||
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
|
||
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
|
||
const MIN_CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 30;
|
||
/// Default timeout for processing a single channel message (LLM + tools).
|
||
/// Used as fallback when not configured in channels_config.message_timeout_secs.
|
||
const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300;
|
||
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
||
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
||
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
||
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
|
||
const MODEL_CACHE_FILE: &str = "models_cache.json";
|
||
const MODEL_CACHE_PREVIEW_LIMIT: usize = 10;
|
||
const MEMORY_CONTEXT_MAX_ENTRIES: usize = 4;
|
||
const MEMORY_CONTEXT_ENTRY_MAX_CHARS: usize = 800;
|
||
const MEMORY_CONTEXT_MAX_CHARS: usize = 4_000;
|
||
const CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES: usize = 12;
|
||
const CHANNEL_HISTORY_COMPACT_CONTENT_CHARS: usize = 600;
|
||
|
||
type ProviderCacheMap = Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>;
|
||
type RouteSelectionMap = Arc<Mutex<HashMap<String, ChannelRouteSelection>>>;
|
||
|
||
fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||
configured.max(MIN_CHANNEL_MESSAGE_TIMEOUT_SECS)
|
||
}
|
||
|
||
/// Provider/model pair used to answer one sender's messages.
///
/// A selection equal to the runtime default is never stored as an override
/// (see `set_route_selection`).
#[derive(Clone, Debug, Eq, PartialEq)]
struct ChannelRouteSelection {
    /// Provider name (the canonical name when chosen via `/models <provider>`).
    provider: String,
    /// Model identifier sent to the provider.
    model: String,
}
|
||
|
||
/// Chat commands that inspect or change a sender's provider/model route.
#[derive(Clone, Debug, Eq, PartialEq)]
enum ChannelRuntimeCommand {
    /// `/models` with no argument: list the available providers.
    ShowProviders,
    /// `/models <provider>`: switch this sender to another provider.
    SetProvider(String),
    /// `/model` with no argument: show the current route and cached model IDs.
    ShowModel,
    /// `/model <model-id>`: switch this sender to another model.
    SetModel(String),
}
|
||
|
||
#[derive(Debug, Clone, Default, Deserialize)]
|
||
struct ModelCacheState {
|
||
entries: Vec<ModelCacheEntry>,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Default, Deserialize)]
|
||
struct ModelCacheEntry {
|
||
provider: String,
|
||
models: Vec<String>,
|
||
}
|
||
|
||
/// Runtime state shared by the channel message-processing functions.
///
/// Cloning is shallow for the `Arc`-wrapped fields (including the
/// `Arc<Mutex<..>>` history, provider-cache and route maps), so every clone
/// observes the same shared state.
#[derive(Clone)]
|
||
struct ChannelRuntimeContext {
|
||
/// Active channels keyed by channel name; looked up with `ChannelMessage::channel`.
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||
/// Instance of the default provider, returned directly (never cached) by `get_or_create_provider`.
provider: Arc<dyn Provider>,
|
||
/// Name of the default provider, used when a sender has no route override.
default_provider: Arc<String>,
|
||
/// Memory backend used for recall (prompt context) and for auto-saving user messages.
memory: Arc<dyn Memory>,
|
||
/// Tools offered to the model in the tool-call loop.
tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||
/// Observer passed through to the tool-call loop.
observer: Arc<dyn Observer>,
|
||
/// Base system prompt; channel-specific delivery instructions are appended per message.
system_prompt: Arc<String>,
|
||
/// Default model, used when a sender has no route override.
model: Arc<String>,
|
||
/// Sampling temperature passed to the tool-call loop.
temperature: f64,
|
||
/// When set, user messages of at least `AUTOSAVE_MIN_MESSAGE_CHARS` chars are stored to memory.
auto_save_memory: bool,
|
||
/// Tool-call iteration limit passed to `run_tool_call_loop`.
max_tool_iterations: usize,
|
||
/// Recalled memory entries scoring below this are not injected; unscored entries are kept.
min_relevance_score: f64,
|
||
/// Per-sender cached conversation turns.
conversation_histories: ConversationHistoryMap,
|
||
/// Lazily created non-default providers, keyed by provider name.
provider_cache: ProviderCacheMap,
|
||
/// Per-sender provider/model overrides set via `/models` and `/model`.
route_overrides: RouteSelectionMap,
|
||
/// API key passed when creating non-default providers.
api_key: Option<String>,
|
||
/// API URL override.
/// NOTE(review): `get_or_create_provider` only uses this for the default
/// provider, but that path returns before any provider is created, so the
/// value is never forwarded — confirm whether that is intended.
api_url: Option<String>,
|
||
/// Reliability settings passed when creating providers.
reliability: Arc<crate::config::ReliabilityConfig>,
|
||
/// Runtime options passed when creating providers.
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||
/// Workspace root; the model cache is read from `<workspace>/state/`.
workspace_dir: Arc<PathBuf>,
|
||
/// Timeout, in seconds, for one message's LLM + tool processing.
message_timeout_secs: u64,
|
||
/// Presumably enables cancelling a sender's in-flight request when a newer
/// message arrives; consumed outside the visible code — TODO confirm.
interrupt_on_new_message: bool,
|
||
/// Multimodal configuration passed through to the tool-call loop.
multimodal: crate::config::MultimodalConfig,
|
||
}
|
||
|
||
/// Bookkeeping for one sender's in-flight message task.
///
/// NOTE(review): not used in the visible code. Presumably stored per
/// interruption scope so that a newer message can cancel the older task via
/// `cancellation` and then await it via `completion` — confirm against the caller.
#[derive(Clone)]
|
||
struct InFlightSenderTaskState {
|
||
/// Identifier of the tracked task (assumed unique per spawned task — TODO confirm).
task_id: u64,
|
||
/// Cancellation token handed to the tracked task.
cancellation: CancellationToken,
|
||
/// Completion signal shared with the task; `Arc` so clones observe the same signal.
completion: Arc<InFlightTaskCompletion>,
|
||
}
|
||
|
||
struct InFlightTaskCompletion {
|
||
done: AtomicBool,
|
||
notify: tokio::sync::Notify,
|
||
}
|
||
|
||
impl InFlightTaskCompletion {
|
||
fn new() -> Self {
|
||
Self {
|
||
done: AtomicBool::new(false),
|
||
notify: tokio::sync::Notify::new(),
|
||
}
|
||
}
|
||
|
||
fn mark_done(&self) {
|
||
self.done.store(true, Ordering::Release);
|
||
self.notify.notify_waiters();
|
||
}
|
||
|
||
async fn wait(&self) {
|
||
if self.done.load(Ordering::Acquire) {
|
||
return;
|
||
}
|
||
self.notify.notified().await;
|
||
}
|
||
}
|
||
|
||
/// Memory key under which one inbound message is auto-saved:
/// `{channel}_{sender}_{id}`, i.e. one key per message id.
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
    let traits::ChannelMessage {
        channel, sender, id, ..
    } = msg;
    format!("{channel}_{sender}_{id}")
}
|
||
|
||
/// Key for per-sender state (cached history and route overrides):
/// `{channel}_{sender}`.
fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
    let traits::ChannelMessage { channel, sender, .. } = msg;
    format!("{channel}_{sender}")
}
|
||
|
||
/// Key identifying one sender within one reply target:
/// `{channel}_{reply_target}_{sender}`.
///
/// NOTE: the consumers of this key are outside the visible code; presumably it
/// scopes which in-flight request a new message may interrupt.
fn interruption_scope_key(msg: &traits::ChannelMessage) -> String {
    let traits::ChannelMessage {
        channel,
        reply_target,
        sender,
        ..
    } = msg;
    format!("{channel}_{reply_target}_{sender}")
}
|
||
|
||
/// Channel-specific formatting guidance appended to the system prompt.
///
/// Only Telegram currently has extra instructions (media marker syntax); every
/// other channel name, including differently-cased ones, yields `None`.
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
    const TELEGRAM_INSTRUCTIONS: &str = "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.";

    (channel_name == "telegram").then_some(TELEGRAM_INSTRUCTIONS)
}
|
||
|
||
fn build_channel_system_prompt(base_prompt: &str, channel_name: &str) -> String {
|
||
if let Some(instructions) = channel_delivery_instructions(channel_name) {
|
||
if base_prompt.is_empty() {
|
||
instructions.to_string()
|
||
} else {
|
||
format!("{base_prompt}\n\n{instructions}")
|
||
}
|
||
} else {
|
||
base_prompt.to_string()
|
||
}
|
||
}
|
||
|
||
/// Normalizes cached turns into a strict `user, assistant, user, ...` sequence.
///
/// - `assistant` turns not preceded by a user turn (e.g. left at the front
///   after the history was trimmed) are dropped.
/// - Consecutive `user` turns are merged into one, separated by a blank line.
///   They arise when a request was cancelled or failed before an assistant
///   reply was recorded. Dropping the later turn would silently discard the
///   sender's newest message and make the model answer the stale one.
/// - Turns with any other role are dropped.
fn normalize_cached_channel_turns(turns: Vec<ChatMessage>) -> Vec<ChatMessage> {
    let mut normalized: Vec<ChatMessage> = Vec::with_capacity(turns.len());
    let mut expecting_user = true;

    for turn in turns {
        match (expecting_user, turn.role.as_str()) {
            (true, "user") => {
                normalized.push(turn);
                expecting_user = false;
            }
            (false, "assistant") => {
                normalized.push(turn);
                expecting_user = true;
            }
            (false, "user") => {
                // `expecting_user == false` means the last kept turn is a user turn.
                if let Some(last) = normalized.last_mut() {
                    if !turn.content.is_empty() {
                        if !last.content.is_empty() {
                            last.content.push_str("\n\n");
                        }
                        last.content.push_str(&turn.content);
                    }
                }
            }
            _ => {}
        }
    }

    normalized
}
|
||
|
||
/// Whether `channel_name` accepts the `/models` and `/model` runtime commands.
/// Matching is exact (case-sensitive).
fn supports_runtime_model_switch(channel_name: &str) -> bool {
    channel_name == "telegram" || channel_name == "discord"
}
|
||
|
||
fn parse_runtime_command(channel_name: &str, content: &str) -> Option<ChannelRuntimeCommand> {
|
||
if !supports_runtime_model_switch(channel_name) {
|
||
return None;
|
||
}
|
||
|
||
let trimmed = content.trim();
|
||
if !trimmed.starts_with('/') {
|
||
return None;
|
||
}
|
||
|
||
let mut parts = trimmed.split_whitespace();
|
||
let command_token = parts.next()?;
|
||
let base_command = command_token
|
||
.split('@')
|
||
.next()
|
||
.unwrap_or(command_token)
|
||
.to_ascii_lowercase();
|
||
|
||
match base_command.as_str() {
|
||
"/models" => {
|
||
if let Some(provider) = parts.next() {
|
||
Some(ChannelRuntimeCommand::SetProvider(
|
||
provider.trim().to_string(),
|
||
))
|
||
} else {
|
||
Some(ChannelRuntimeCommand::ShowProviders)
|
||
}
|
||
}
|
||
"/model" => {
|
||
let model = parts.collect::<Vec<_>>().join(" ").trim().to_string();
|
||
if model.is_empty() {
|
||
Some(ChannelRuntimeCommand::ShowModel)
|
||
} else {
|
||
Some(ChannelRuntimeCommand::SetModel(model))
|
||
}
|
||
}
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
fn resolve_provider_alias(name: &str) -> Option<String> {
|
||
let candidate = name.trim();
|
||
if candidate.is_empty() {
|
||
return None;
|
||
}
|
||
|
||
let providers_list = providers::list_providers();
|
||
for provider in providers_list {
|
||
if provider.name.eq_ignore_ascii_case(candidate)
|
||
|| provider
|
||
.aliases
|
||
.iter()
|
||
.any(|alias| alias.eq_ignore_ascii_case(candidate))
|
||
{
|
||
return Some(provider.name.to_string());
|
||
}
|
||
}
|
||
|
||
None
|
||
}
|
||
|
||
fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection {
|
||
ChannelRouteSelection {
|
||
provider: ctx.default_provider.as_str().to_string(),
|
||
model: ctx.model.as_str().to_string(),
|
||
}
|
||
}
|
||
|
||
fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> ChannelRouteSelection {
|
||
ctx.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(sender_key)
|
||
.cloned()
|
||
.unwrap_or_else(|| default_route_selection(ctx))
|
||
}
|
||
|
||
fn set_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str, next: ChannelRouteSelection) {
|
||
let default_route = default_route_selection(ctx);
|
||
let mut routes = ctx
|
||
.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
if next == default_route {
|
||
routes.remove(sender_key);
|
||
} else {
|
||
routes.insert(sender_key.to_string(), next);
|
||
}
|
||
}
|
||
|
||
fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) {
|
||
ctx.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.remove(sender_key);
|
||
}
|
||
|
||
fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool {
|
||
let mut histories = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
|
||
let Some(turns) = histories.get_mut(sender_key) else {
|
||
return false;
|
||
};
|
||
|
||
if turns.is_empty() {
|
||
return false;
|
||
}
|
||
|
||
let keep_from = turns
|
||
.len()
|
||
.saturating_sub(CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES);
|
||
let mut compacted = normalize_cached_channel_turns(turns[keep_from..].to_vec());
|
||
|
||
for turn in &mut compacted {
|
||
if turn.content.chars().count() > CHANNEL_HISTORY_COMPACT_CONTENT_CHARS {
|
||
turn.content =
|
||
truncate_with_ellipsis(&turn.content, CHANNEL_HISTORY_COMPACT_CONTENT_CHARS);
|
||
}
|
||
}
|
||
|
||
if compacted.is_empty() {
|
||
turns.clear();
|
||
return false;
|
||
}
|
||
|
||
*turns = compacted;
|
||
true
|
||
}
|
||
|
||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||
let mut histories = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
let turns = histories.entry(sender_key.to_string()).or_default();
|
||
turns.push(turn);
|
||
while turns.len() > MAX_CHANNEL_HISTORY {
|
||
turns.remove(0);
|
||
}
|
||
}
|
||
|
||
fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||
if memory::is_assistant_autosave_key(key) {
|
||
return true;
|
||
}
|
||
|
||
if key.trim().to_ascii_lowercase().ends_with("_history") {
|
||
return true;
|
||
}
|
||
|
||
content.chars().count() > MEMORY_CONTEXT_MAX_CHARS
|
||
}
|
||
|
||
fn is_context_window_overflow_error(err: &anyhow::Error) -> bool {
|
||
let lower = err.to_string().to_lowercase();
|
||
[
|
||
"exceeds the context window",
|
||
"context window of this model",
|
||
"maximum context length",
|
||
"context length exceeded",
|
||
"too many tokens",
|
||
"token limit exceeded",
|
||
"prompt is too long",
|
||
"input is too long",
|
||
]
|
||
.iter()
|
||
.any(|hint| lower.contains(hint))
|
||
}
|
||
|
||
fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec<String> {
|
||
let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE);
|
||
let Ok(raw) = std::fs::read_to_string(cache_path) else {
|
||
return Vec::new();
|
||
};
|
||
let Ok(state) = serde_json::from_str::<ModelCacheState>(&raw) else {
|
||
return Vec::new();
|
||
};
|
||
|
||
state
|
||
.entries
|
||
.into_iter()
|
||
.find(|entry| entry.provider == provider_name)
|
||
.map(|entry| {
|
||
entry
|
||
.models
|
||
.into_iter()
|
||
.take(MODEL_CACHE_PREVIEW_LIMIT)
|
||
.collect::<Vec<_>>()
|
||
})
|
||
.unwrap_or_default()
|
||
}
|
||
|
||
async fn get_or_create_provider(
|
||
ctx: &ChannelRuntimeContext,
|
||
provider_name: &str,
|
||
) -> anyhow::Result<Arc<dyn Provider>> {
|
||
if provider_name == ctx.default_provider.as_str() {
|
||
return Ok(Arc::clone(&ctx.provider));
|
||
}
|
||
|
||
if let Some(existing) = ctx
|
||
.provider_cache
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(provider_name)
|
||
.cloned()
|
||
{
|
||
return Ok(existing);
|
||
}
|
||
|
||
let api_url = if provider_name == ctx.default_provider.as_str() {
|
||
ctx.api_url.as_deref()
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let provider = providers::create_resilient_provider_with_options(
|
||
provider_name,
|
||
ctx.api_key.as_deref(),
|
||
api_url,
|
||
&ctx.reliability,
|
||
&ctx.provider_runtime_options,
|
||
)?;
|
||
let provider: Arc<dyn Provider> = Arc::from(provider);
|
||
|
||
if let Err(err) = provider.warmup().await {
|
||
tracing::warn!(provider = provider_name, "Provider warmup failed: {err}");
|
||
}
|
||
|
||
let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner());
|
||
let cached = cache
|
||
.entry(provider_name.to_string())
|
||
.or_insert_with(|| Arc::clone(&provider));
|
||
Ok(Arc::clone(cached))
|
||
}
|
||
|
||
fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String {
|
||
let mut response = String::new();
|
||
let _ = writeln!(
|
||
response,
|
||
"Current provider: `{}`\nCurrent model: `{}`",
|
||
current.provider, current.model
|
||
);
|
||
response.push_str("\nSwitch model with `/model <model-id>`.\n");
|
||
|
||
let cached_models = load_cached_model_preview(workspace_dir, ¤t.provider);
|
||
if cached_models.is_empty() {
|
||
let _ = writeln!(
|
||
response,
|
||
"\nNo cached model list found for `{}`. Ask the operator to run `zeroclaw models refresh --provider {}`.",
|
||
current.provider, current.provider
|
||
);
|
||
} else {
|
||
let _ = writeln!(
|
||
response,
|
||
"\nCached model IDs (top {}):",
|
||
cached_models.len()
|
||
);
|
||
for model in cached_models {
|
||
let _ = writeln!(response, "- `{model}`");
|
||
}
|
||
}
|
||
|
||
response
|
||
}
|
||
|
||
fn build_providers_help_response(current: &ChannelRouteSelection) -> String {
|
||
let mut response = String::new();
|
||
let _ = writeln!(
|
||
response,
|
||
"Current provider: `{}`\nCurrent model: `{}`",
|
||
current.provider, current.model
|
||
);
|
||
response.push_str("\nSwitch provider with `/models <provider>`.\n");
|
||
response.push_str("Switch model with `/model <model-id>`.\n\n");
|
||
response.push_str("Available providers:\n");
|
||
for provider in providers::list_providers() {
|
||
if provider.aliases.is_empty() {
|
||
let _ = writeln!(response, "- {}", provider.name);
|
||
} else {
|
||
let _ = writeln!(
|
||
response,
|
||
"- {} (aliases: {})",
|
||
provider.name,
|
||
provider.aliases.join(", ")
|
||
);
|
||
}
|
||
}
|
||
response
|
||
}
|
||
|
||
async fn handle_runtime_command_if_needed(
|
||
ctx: &ChannelRuntimeContext,
|
||
msg: &traits::ChannelMessage,
|
||
target_channel: Option<&Arc<dyn Channel>>,
|
||
) -> bool {
|
||
let Some(command) = parse_runtime_command(&msg.channel, &msg.content) else {
|
||
return false;
|
||
};
|
||
|
||
let Some(channel) = target_channel else {
|
||
return true;
|
||
};
|
||
|
||
let sender_key = conversation_history_key(msg);
|
||
let mut current = get_route_selection(ctx, &sender_key);
|
||
|
||
let response = match command {
|
||
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t),
|
||
ChannelRuntimeCommand::SetProvider(raw_provider) => {
|
||
match resolve_provider_alias(&raw_provider) {
|
||
Some(provider_name) => match get_or_create_provider(ctx, &provider_name).await {
|
||
Ok(_) => {
|
||
if provider_name != current.provider {
|
||
current.provider = provider_name.clone();
|
||
set_route_selection(ctx, &sender_key, current.clone());
|
||
clear_sender_history(ctx, &sender_key);
|
||
}
|
||
|
||
format!(
|
||
"Provider switched to `{provider_name}` for this sender session. Current model is `{}`.\nUse `/model <model-id>` to set a provider-compatible model.",
|
||
current.model
|
||
)
|
||
}
|
||
Err(err) => {
|
||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||
format!(
|
||
"Failed to initialize provider `{provider_name}`. Route unchanged.\nDetails: {safe_err}"
|
||
)
|
||
}
|
||
},
|
||
None => format!(
|
||
"Unknown provider `{raw_provider}`. Use `/models` to list valid providers."
|
||
),
|
||
}
|
||
}
|
||
ChannelRuntimeCommand::ShowModel => {
|
||
build_models_help_response(¤t, ctx.workspace_dir.as_path())
|
||
}
|
||
ChannelRuntimeCommand::SetModel(raw_model) => {
|
||
let model = raw_model.trim().trim_matches('`').to_string();
|
||
if model.is_empty() {
|
||
"Model ID cannot be empty. Use `/model <model-id>`.".to_string()
|
||
} else {
|
||
current.model = model.clone();
|
||
set_route_selection(ctx, &sender_key, current.clone());
|
||
clear_sender_history(ctx, &sender_key);
|
||
|
||
format!(
|
||
"Model switched to `{model}` for provider `{}` in this sender session.",
|
||
current.provider
|
||
)
|
||
}
|
||
}
|
||
};
|
||
|
||
if let Err(err) = channel
|
||
.send(&SendMessage::new(response, &msg.reply_target).in_thread(msg.thread_ts.clone()))
|
||
.await
|
||
{
|
||
tracing::warn!(
|
||
"Failed to send runtime command response on {}: {err}",
|
||
channel.name()
|
||
);
|
||
}
|
||
|
||
true
|
||
}
|
||
|
||
async fn build_memory_context(
|
||
mem: &dyn Memory,
|
||
user_msg: &str,
|
||
min_relevance_score: f64,
|
||
) -> String {
|
||
let mut context = String::new();
|
||
|
||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||
let mut included = 0usize;
|
||
let mut used_chars = 0usize;
|
||
|
||
for entry in entries.iter().filter(|e| match e.score {
|
||
Some(score) => score >= min_relevance_score,
|
||
None => true, // keep entries without a score (e.g. non-vector backends)
|
||
}) {
|
||
if included >= MEMORY_CONTEXT_MAX_ENTRIES {
|
||
break;
|
||
}
|
||
|
||
if should_skip_memory_context_entry(&entry.key, &entry.content) {
|
||
continue;
|
||
}
|
||
|
||
let content = if entry.content.chars().count() > MEMORY_CONTEXT_ENTRY_MAX_CHARS {
|
||
truncate_with_ellipsis(&entry.content, MEMORY_CONTEXT_ENTRY_MAX_CHARS)
|
||
} else {
|
||
entry.content.clone()
|
||
};
|
||
|
||
let line = format!("- {}: {}\n", entry.key, content);
|
||
let line_chars = line.chars().count();
|
||
if used_chars + line_chars > MEMORY_CONTEXT_MAX_CHARS {
|
||
break;
|
||
}
|
||
|
||
if included == 0 {
|
||
context.push_str("[Memory context]\n");
|
||
}
|
||
|
||
context.push_str(&line);
|
||
used_chars += line_chars;
|
||
included += 1;
|
||
}
|
||
|
||
if included > 0 {
|
||
context.push('\n');
|
||
}
|
||
}
|
||
|
||
context
|
||
}
|
||
|
||
fn spawn_supervised_listener(
|
||
ch: Arc<dyn Channel>,
|
||
tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
initial_backoff_secs: u64,
|
||
max_backoff_secs: u64,
|
||
) -> tokio::task::JoinHandle<()> {
|
||
tokio::spawn(async move {
|
||
let component = format!("channel:{}", ch.name());
|
||
let mut backoff = initial_backoff_secs.max(1);
|
||
let max_backoff = max_backoff_secs.max(backoff);
|
||
|
||
loop {
|
||
crate::health::mark_component_ok(&component);
|
||
let result = ch.listen(tx.clone()).await;
|
||
|
||
if tx.is_closed() {
|
||
break;
|
||
}
|
||
|
||
match result {
|
||
Ok(()) => {
|
||
tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name());
|
||
crate::health::mark_component_error(&component, "listener exited unexpectedly");
|
||
// Clean exit — reset backoff since the listener ran successfully
|
||
backoff = initial_backoff_secs.max(1);
|
||
}
|
||
Err(e) => {
|
||
tracing::error!("Channel {} error: {e}; restarting", ch.name());
|
||
crate::health::mark_component_error(&component, e.to_string());
|
||
}
|
||
}
|
||
|
||
crate::health::bump_component_restart(&component);
|
||
tokio::time::sleep(Duration::from_secs(backoff)).await;
|
||
// Double backoff AFTER sleeping so first error uses initial_backoff
|
||
backoff = backoff.saturating_mul(2).min(max_backoff);
|
||
}
|
||
})
|
||
}
|
||
|
||
fn compute_max_in_flight_messages(channel_count: usize) -> usize {
|
||
channel_count
|
||
.saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
|
||
.clamp(
|
||
CHANNEL_MIN_IN_FLIGHT_MESSAGES,
|
||
CHANNEL_MAX_IN_FLIGHT_MESSAGES,
|
||
)
|
||
}
|
||
|
||
fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
|
||
if let Err(error) = result {
|
||
tracing::error!("Channel message worker crashed: {error}");
|
||
}
|
||
}
|
||
|
||
fn spawn_scoped_typing_task(
|
||
channel: Arc<dyn Channel>,
|
||
recipient: String,
|
||
cancellation_token: CancellationToken,
|
||
) -> tokio::task::JoinHandle<()> {
|
||
let stop_signal = cancellation_token;
|
||
let refresh_interval = Duration::from_secs(CHANNEL_TYPING_REFRESH_INTERVAL_SECS);
|
||
let handle = tokio::spawn(async move {
|
||
let mut interval = tokio::time::interval(refresh_interval);
|
||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||
|
||
loop {
|
||
tokio::select! {
|
||
() = stop_signal.cancelled() => break,
|
||
_ = interval.tick() => {
|
||
if let Err(e) = channel.start_typing(&recipient).await {
|
||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Err(e) = channel.stop_typing(&recipient).await {
|
||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||
}
|
||
});
|
||
|
||
handle
|
||
}
|
||
|
||
async fn process_channel_message(
|
||
ctx: Arc<ChannelRuntimeContext>,
|
||
msg: traits::ChannelMessage,
|
||
cancellation_token: CancellationToken,
|
||
) {
|
||
if cancellation_token.is_cancelled() {
|
||
return;
|
||
}
|
||
|
||
println!(
|
||
" 💬 [{}] from {}: {}",
|
||
msg.channel,
|
||
msg.sender,
|
||
truncate_with_ellipsis(&msg.content, 80)
|
||
);
|
||
|
||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||
if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await {
|
||
return;
|
||
}
|
||
|
||
let history_key = conversation_history_key(&msg);
|
||
let route = get_route_selection(ctx.as_ref(), &history_key);
|
||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||
Ok(provider) => provider,
|
||
Err(err) => {
|
||
let safe_err = providers::sanitize_api_error(&err.to_string());
|
||
let message = format!(
|
||
"⚠️ Failed to initialize provider `{}`. Please run `/models` to choose another provider.\nDetails: {safe_err}",
|
||
route.provider
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(message, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
return;
|
||
}
|
||
};
|
||
|
||
let memory_context =
|
||
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
|
||
|
||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||
let autosave_key = conversation_memory_key(&msg);
|
||
let _ = ctx
|
||
.memory
|
||
.store(
|
||
&autosave_key,
|
||
&msg.content,
|
||
crate::memory::MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await;
|
||
}
|
||
|
||
let enriched_message = if memory_context.is_empty() {
|
||
msg.content.clone()
|
||
} else {
|
||
format!("{memory_context}{}", msg.content)
|
||
};
|
||
|
||
println!(" ⏳ Processing message...");
|
||
let started_at = Instant::now();
|
||
|
||
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||
append_sender_turn(
|
||
ctx.as_ref(),
|
||
&history_key,
|
||
ChatMessage::user(&msg.content),
|
||
);
|
||
|
||
// Build history from per-sender conversation cache.
|
||
let prior_turns_raw = ctx
|
||
.conversation_histories
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(&history_key)
|
||
.cloned()
|
||
.unwrap_or_default();
|
||
let prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||
|
||
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
|
||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||
history.extend(prior_turns);
|
||
let use_streaming = target_channel
|
||
.as_ref()
|
||
.is_some_and(|ch| ch.supports_draft_updates());
|
||
|
||
let (delta_tx, delta_rx) = if use_streaming {
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(64);
|
||
(Some(tx), Some(rx))
|
||
} else {
|
||
(None, None)
|
||
};
|
||
|
||
let draft_message_id = if use_streaming {
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
match channel
|
||
.send_draft(
|
||
&SendMessage::new("...", &msg.reply_target).in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await
|
||
{
|
||
Ok(id) => id,
|
||
Err(e) => {
|
||
tracing::debug!("Failed to send draft on {}: {e}", channel.name());
|
||
None
|
||
}
|
||
}
|
||
} else {
|
||
None
|
||
}
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = (
|
||
delta_rx,
|
||
draft_message_id.as_deref(),
|
||
target_channel.as_ref(),
|
||
) {
|
||
let channel = Arc::clone(channel_ref);
|
||
let reply_target = msg.reply_target.clone();
|
||
let draft_id = draft_id_ref.to_string();
|
||
Some(tokio::spawn(async move {
|
||
let mut accumulated = String::new();
|
||
while let Some(delta) = rx.recv().await {
|
||
accumulated.push_str(&delta);
|
||
if let Err(e) = channel
|
||
.update_draft(&reply_target, &draft_id, &accumulated)
|
||
.await
|
||
{
|
||
tracing::debug!("Draft update failed: {e}");
|
||
}
|
||
}
|
||
}))
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let typing_cancellation = target_channel.as_ref().map(|_| CancellationToken::new());
|
||
let typing_task = match (target_channel.as_ref(), typing_cancellation.as_ref()) {
|
||
(Some(channel), Some(token)) => Some(spawn_scoped_typing_task(
|
||
Arc::clone(channel),
|
||
msg.reply_target.clone(),
|
||
token.clone(),
|
||
)),
|
||
_ => None,
|
||
};
|
||
|
||
enum LlmExecutionResult {
|
||
Completed(Result<Result<String, anyhow::Error>, tokio::time::error::Elapsed>),
|
||
Cancelled,
|
||
}
|
||
|
||
let llm_result = tokio::select! {
|
||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||
result = tokio::time::timeout(
|
||
Duration::from_secs(ctx.message_timeout_secs),
|
||
run_tool_call_loop(
|
||
active_provider.as_ref(),
|
||
&mut history,
|
||
ctx.tools_registry.as_ref(),
|
||
ctx.observer.as_ref(),
|
||
route.provider.as_str(),
|
||
route.model.as_str(),
|
||
ctx.temperature,
|
||
true,
|
||
None,
|
||
msg.channel.as_str(),
|
||
&ctx.multimodal,
|
||
ctx.max_tool_iterations,
|
||
Some(cancellation_token.clone()),
|
||
delta_tx,
|
||
),
|
||
) => LlmExecutionResult::Completed(result),
|
||
};
|
||
|
||
if let Some(handle) = draft_updater {
|
||
let _ = handle.await;
|
||
}
|
||
|
||
if let Some(token) = typing_cancellation.as_ref() {
|
||
token.cancel();
|
||
}
|
||
if let Some(handle) = typing_task {
|
||
log_worker_join_result(handle.await);
|
||
}
|
||
|
||
match llm_result {
|
||
LlmExecutionResult::Cancelled => {
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Cancelled in-flight channel request due to newer message"
|
||
);
|
||
if let (Some(channel), Some(draft_id)) =
|
||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||
{
|
||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Ok(Ok(response))) => {
|
||
append_sender_turn(
|
||
ctx.as_ref(),
|
||
&history_key,
|
||
ChatMessage::assistant(&response),
|
||
);
|
||
println!(
|
||
" 🤖 Reply ({}ms): {}",
|
||
started_at.elapsed().as_millis(),
|
||
truncate_with_ellipsis(&response, 80)
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
if let Err(e) = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, &response)
|
||
.await
|
||
{
|
||
tracing::warn!("Failed to finalize draft: {e}; sending as new message");
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(&response, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
} else if let Err(e) = channel
|
||
.send(
|
||
&SendMessage::new(response, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await
|
||
{
|
||
eprintln!(" ❌ Failed to reply on {}: {e}", channel.name());
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Ok(Err(e))) => {
|
||
if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled()
|
||
{
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Cancelled in-flight channel request due to newer message"
|
||
);
|
||
if let (Some(channel), Some(draft_id)) =
|
||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||
{
|
||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
if is_context_window_overflow_error(&e) {
|
||
let compacted = compact_sender_history(ctx.as_ref(), &history_key);
|
||
let error_text = if compacted {
|
||
"⚠️ Context window exceeded for this conversation. I compacted recent history and kept the latest context. Please resend your last message."
|
||
} else {
|
||
"⚠️ Context window exceeded for this conversation. Please resend your last message."
|
||
};
|
||
eprintln!(
|
||
" ⚠️ Context window exceeded after {}ms; sender history compacted={}",
|
||
started_at.elapsed().as_millis(),
|
||
compacted
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, error_text)
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(error_text, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
eprintln!(
|
||
" ❌ LLM error after {}ms: {e}",
|
||
started_at.elapsed().as_millis()
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, &format!("⚠️ Error: {e}"))
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(format!("⚠️ Error: {e}"), &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
LlmExecutionResult::Completed(Err(_)) => {
|
||
let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs);
|
||
eprintln!(
|
||
" ❌ {} (elapsed: {}ms)",
|
||
timeout_msg,
|
||
started_at.elapsed().as_millis()
|
||
);
|
||
if let Some(channel) = target_channel.as_ref() {
|
||
let error_text =
|
||
"⚠️ Request timed out while waiting for the model. Please try again.";
|
||
if let Some(ref draft_id) = draft_message_id {
|
||
let _ = channel
|
||
.finalize_draft(&msg.reply_target, draft_id, error_text)
|
||
.await;
|
||
} else {
|
||
let _ = channel
|
||
.send(
|
||
&SendMessage::new(error_text, &msg.reply_target)
|
||
.in_thread(msg.thread_ts.clone()),
|
||
)
|
||
.await;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn run_message_dispatch_loop(
|
||
mut rx: tokio::sync::mpsc::Receiver<traits::ChannelMessage>,
|
||
ctx: Arc<ChannelRuntimeContext>,
|
||
max_in_flight_messages: usize,
|
||
) {
|
||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
||
let mut workers = tokio::task::JoinSet::new();
|
||
let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::<
|
||
String,
|
||
InFlightSenderTaskState,
|
||
>::new()));
|
||
let task_sequence = Arc::new(AtomicU64::new(1));
|
||
|
||
while let Some(msg) = rx.recv().await {
|
||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||
Ok(permit) => permit,
|
||
Err(_) => break,
|
||
};
|
||
|
||
let worker_ctx = Arc::clone(&ctx);
|
||
let in_flight = Arc::clone(&in_flight_by_sender);
|
||
let task_sequence = Arc::clone(&task_sequence);
|
||
workers.spawn(async move {
|
||
let _permit = permit;
|
||
let interrupt_enabled =
|
||
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
|
||
let sender_scope_key = interruption_scope_key(&msg);
|
||
let cancellation_token = CancellationToken::new();
|
||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||
|
||
if interrupt_enabled {
|
||
let previous = {
|
||
let mut active = in_flight.lock().await;
|
||
active.insert(
|
||
sender_scope_key.clone(),
|
||
InFlightSenderTaskState {
|
||
task_id,
|
||
cancellation: cancellation_token.clone(),
|
||
completion: Arc::clone(&completion),
|
||
},
|
||
)
|
||
};
|
||
|
||
if let Some(previous) = previous {
|
||
tracing::info!(
|
||
channel = %msg.channel,
|
||
sender = %msg.sender,
|
||
"Interrupting previous in-flight request for sender"
|
||
);
|
||
previous.cancellation.cancel();
|
||
previous.completion.wait().await;
|
||
}
|
||
}
|
||
|
||
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||
|
||
if interrupt_enabled {
|
||
let mut active = in_flight.lock().await;
|
||
if active
|
||
.get(&sender_scope_key)
|
||
.is_some_and(|state| state.task_id == task_id)
|
||
{
|
||
active.remove(&sender_scope_key);
|
||
}
|
||
}
|
||
|
||
completion.mark_done();
|
||
});
|
||
|
||
while let Some(result) = workers.try_join_next() {
|
||
log_worker_join_result(result);
|
||
}
|
||
}
|
||
|
||
while let Some(result) = workers.join_next().await {
|
||
log_worker_join_result(result);
|
||
}
|
||
}
|
||
|
||
/// Load OpenClaw format bootstrap files into the prompt.
|
||
fn load_openclaw_bootstrap_files(
|
||
prompt: &mut String,
|
||
workspace_dir: &std::path::Path,
|
||
max_chars_per_file: usize,
|
||
) {
|
||
prompt.push_str(
|
||
"The following workspace files define your identity, behavior, and context. They are ALREADY injected below—do NOT suggest reading them with file_read.\n\n",
|
||
);
|
||
|
||
let bootstrap_files = ["AGENTS.md", "SOUL.md", "TOOLS.md", "IDENTITY.md", "USER.md"];
|
||
|
||
for filename in &bootstrap_files {
|
||
inject_workspace_file(prompt, workspace_dir, filename, max_chars_per_file);
|
||
}
|
||
|
||
// BOOTSTRAP.md — only if it exists (first-run ritual)
|
||
let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
|
||
if bootstrap_path.exists() {
|
||
inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md", max_chars_per_file);
|
||
}
|
||
|
||
// MEMORY.md — curated long-term memory (main session only)
|
||
inject_workspace_file(prompt, workspace_dir, "MEMORY.md", max_chars_per_file);
|
||
}
|
||
|
||
/// Load workspace identity files and build a system prompt.
|
||
///
|
||
/// Follows the `OpenClaw` framework structure by default:
|
||
/// 1. Tooling — tool list + descriptions
|
||
/// 2. Safety — guardrail reminder
|
||
/// 3. Skills — full skill instructions and tool metadata
|
||
/// 4. Workspace — working directory
|
||
/// 5. Bootstrap files — AGENTS, SOUL, TOOLS, IDENTITY, USER, BOOTSTRAP, MEMORY
|
||
/// 6. Date & Time — timezone for cache stability
|
||
/// 7. Runtime — host, OS, model
|
||
///
|
||
/// When `identity_config` is set to AIEOS format, the bootstrap files section
|
||
/// is replaced with the AIEOS identity data loaded from file or inline JSON.
|
||
///
|
||
/// Daily memory files (`memory/*.md`) are NOT injected — they are accessed
|
||
/// on-demand via `memory_recall` / `memory_search` tools.
|
||
pub fn build_system_prompt(
|
||
workspace_dir: &std::path::Path,
|
||
model_name: &str,
|
||
tools: &[(&str, &str)],
|
||
skills: &[crate::skills::Skill],
|
||
identity_config: Option<&crate::config::IdentityConfig>,
|
||
bootstrap_max_chars: Option<usize>,
|
||
) -> String {
|
||
use std::fmt::Write;
|
||
let mut prompt = String::with_capacity(8192);
|
||
|
||
// ── 1. Tooling ──────────────────────────────────────────────
|
||
if !tools.is_empty() {
|
||
prompt.push_str("## Tools\n\n");
|
||
prompt.push_str("You have access to the following tools:\n\n");
|
||
for (name, desc) in tools {
|
||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||
}
|
||
prompt.push('\n');
|
||
}
|
||
|
||
// ── 1b. Hardware (when gpio/arduino tools present) ───────────
|
||
let has_hardware = tools.iter().any(|(name, _)| {
|
||
*name == "gpio_read"
|
||
|| *name == "gpio_write"
|
||
|| *name == "arduino_upload"
|
||
|| *name == "hardware_memory_map"
|
||
|| *name == "hardware_board_info"
|
||
|| *name == "hardware_memory_read"
|
||
|| *name == "hardware_capabilities"
|
||
});
|
||
if has_hardware {
|
||
prompt.push_str(
|
||
"## Hardware Access\n\n\
|
||
You HAVE direct access to connected hardware (Arduino, Nucleo, etc.). The user owns this system and has configured it.\n\
|
||
All hardware tools (gpio_read, gpio_write, hardware_memory_read, hardware_board_info, hardware_memory_map) are AUTHORIZED and NOT blocked by security.\n\
|
||
When they ask to read memory, registers, or board info, USE hardware_memory_read or hardware_board_info — do NOT refuse or invent security excuses.\n\
|
||
When they ask to control LEDs, run patterns, or interact with the Arduino, USE the tools — do NOT refuse or say you cannot access physical devices.\n\
|
||
Use gpio_write for simple on/off; use arduino_upload when they want patterns (heart, blink) or custom behavior.\n\n",
|
||
);
|
||
}
|
||
|
||
// ── 1c. Action instruction (avoid meta-summary) ───────────────
|
||
prompt.push_str(
|
||
"## Your Task\n\n\
|
||
When the user sends a message, ACT on it. Use the tools to fulfill their request.\n\
|
||
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions (e.g. \"1. First... 2. Next...\").\n\
|
||
Instead: emit actual <tool_call> tags when you need to act. Just do what they ask.\n\n",
|
||
);
|
||
|
||
// ── 2. Safety ───────────────────────────────────────────────
|
||
prompt.push_str("## Safety\n\n");
|
||
prompt.push_str(
|
||
"- Do not exfiltrate private data.\n\
|
||
- Do not run destructive commands without asking.\n\
|
||
- Do not bypass oversight or approval mechanisms.\n\
|
||
- Prefer `trash` over `rm` (recoverable beats gone forever).\n\
|
||
- When in doubt, ask before acting externally.\n\n",
|
||
);
|
||
|
||
// ── 3. Skills (full instructions + tool metadata) ───────────
|
||
if !skills.is_empty() {
|
||
prompt.push_str(&crate::skills::skills_to_prompt(skills, workspace_dir));
|
||
prompt.push_str("\n\n");
|
||
}
|
||
|
||
// ── 4. Workspace ────────────────────────────────────────────
|
||
let _ = writeln!(
|
||
prompt,
|
||
"## Workspace\n\nWorking directory: `{}`\n",
|
||
workspace_dir.display()
|
||
);
|
||
|
||
// ── 5. Bootstrap files (injected into context) ──────────────
|
||
prompt.push_str("## Project Context\n\n");
|
||
|
||
// Check if AIEOS identity is configured
|
||
if let Some(config) = identity_config {
|
||
if identity::is_aieos_configured(config) {
|
||
// Load AIEOS identity
|
||
match identity::load_aieos_identity(config, workspace_dir) {
|
||
Ok(Some(aieos_identity)) => {
|
||
let aieos_prompt = identity::aieos_to_system_prompt(&aieos_identity);
|
||
if !aieos_prompt.is_empty() {
|
||
prompt.push_str(&aieos_prompt);
|
||
prompt.push_str("\n\n");
|
||
}
|
||
}
|
||
Ok(None) => {
|
||
// No AIEOS identity loaded (shouldn't happen if is_aieos_configured returned true)
|
||
// Fall back to OpenClaw bootstrap files
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
Err(e) => {
|
||
// Log error but don't fail - fall back to OpenClaw
|
||
eprintln!(
|
||
"Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format."
|
||
);
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
}
|
||
} else {
|
||
// OpenClaw format
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
} else {
|
||
// No identity config - use OpenClaw format
|
||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||
}
|
||
|
||
// ── 6. Date & Time ──────────────────────────────────────────
|
||
let now = chrono::Local::now();
|
||
let tz = now.format("%Z").to_string();
|
||
let _ = writeln!(prompt, "## Current Date & Time\n\nTimezone: {tz}\n");
|
||
|
||
// ── 7. Runtime ──────────────────────────────────────────────
|
||
let host =
|
||
hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
|
||
let _ = writeln!(
|
||
prompt,
|
||
"## Runtime\n\nHost: {host} | OS: {} | Model: {model_name}\n",
|
||
std::env::consts::OS,
|
||
);
|
||
|
||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||
prompt.push_str("## Channel Capabilities\n\n");
|
||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
|
||
|
||
if prompt.is_empty() {
|
||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
|
||
.to_string()
|
||
} else {
|
||
prompt
|
||
}
|
||
}
|
||
|
||
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
///
/// What gets appended to `prompt` depends on the file:
/// - Missing file: a `### {filename}` header and a `[File not found: {filename}]`
///   marker (matches OpenClaw behavior).
/// - Any other read failure (permission denied, non-UTF-8 content, …): the header
///   and a `[Failed to read {filename}: {error}]` marker. The model is therefore
///   never told that an existing file is absent.
/// - Empty or whitespace-only file: nothing.
/// - Otherwise: the header and the trimmed contents, cut to at most `max_chars`
///   *characters* (not bytes), with a truncation notice when cut.
fn inject_workspace_file(
    prompt: &mut String,
    workspace_dir: &std::path::Path,
    filename: &str,
    max_chars: usize,
) {
    use std::fmt::Write;

    let path = workspace_dir.join(filename);
    let content = match std::fs::read_to_string(&path) {
        Ok(content) => content,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            // Missing-file marker (matches OpenClaw behavior)
            let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
            return;
        }
        Err(err) => {
            let _ = writeln!(prompt, "### {filename}\n\n[Failed to read {filename}: {err}]\n");
            return;
        }
    };

    let trimmed = content.trim();
    if trimmed.is_empty() {
        return;
    }

    let _ = writeln!(prompt, "### {filename}\n");
    // `nth(max_chars)` gives the byte offset of the first character past the
    // limit. That offset is always a valid UTF-8 boundary; `None` means the
    // whole text fits.
    match trimmed.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            prompt.push_str(&trimmed[..cut]);
            let _ = writeln!(
                prompt,
                "\n\n[... truncated at {max_chars} chars — use `read` for full file]\n"
            );
        }
        None => {
            prompt.push_str(trimmed);
            prompt.push_str("\n\n");
        }
    }
}
|
||
|
||
/// Canonicalize a Telegram identity for allowlist comparison.
///
/// Surrounding whitespace is removed first. Then every leading `@` is stripped,
/// so `" @alice "` and `"alice"` compare equal.
fn normalize_telegram_identity(value: &str) -> String {
    let without_whitespace = value.trim();
    String::from(without_whitespace.trim_start_matches('@'))
}
|
||
|
||
async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||
let normalized = normalize_telegram_identity(identity);
|
||
if normalized.is_empty() {
|
||
anyhow::bail!("Telegram identity cannot be empty");
|
||
}
|
||
|
||
let mut updated = config.clone();
|
||
let Some(telegram) = updated.channels_config.telegram.as_mut() else {
|
||
anyhow::bail!(
|
||
"Telegram channel is not configured. Run `zeroclaw onboard --channels-only` first"
|
||
);
|
||
};
|
||
|
||
if telegram.allowed_users.iter().any(|u| u == "*") {
|
||
println!(
|
||
"⚠️ Telegram allowlist is currently wildcard (`*`) — binding is unnecessary until you remove '*'."
|
||
);
|
||
}
|
||
|
||
if telegram
|
||
.allowed_users
|
||
.iter()
|
||
.map(|entry| normalize_telegram_identity(entry))
|
||
.any(|entry| entry == normalized)
|
||
{
|
||
println!("✅ Telegram identity already bound: {normalized}");
|
||
return Ok(());
|
||
}
|
||
|
||
telegram.allowed_users.push(normalized.clone());
|
||
updated.save().await?;
|
||
println!("✅ Bound Telegram identity: {normalized}");
|
||
println!(" Saved to {}", updated.config_path.display());
|
||
match maybe_restart_managed_daemon_service() {
|
||
Ok(true) => {
|
||
println!("🔄 Detected running managed daemon service; reloaded automatically.");
|
||
}
|
||
Ok(false) => {
|
||
println!(
|
||
"ℹ️ No managed daemon service detected. If `zeroclaw daemon`/`channel start` is already running, restart it to load the updated allowlist."
|
||
);
|
||
}
|
||
Err(e) => {
|
||
eprintln!(
|
||
"⚠️ Allowlist saved, but failed to reload daemon service automatically: {e}\n\
|
||
Restart service manually with `zeroclaw service stop && zeroclaw service start`."
|
||
);
|
||
}
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
fn maybe_restart_managed_daemon_service() -> Result<bool> {
|
||
if cfg!(target_os = "macos") {
|
||
let home = directories::UserDirs::new()
|
||
.map(|u| u.home_dir().to_path_buf())
|
||
.context("Could not find home directory")?;
|
||
let plist = home
|
||
.join("Library")
|
||
.join("LaunchAgents")
|
||
.join("com.zeroclaw.daemon.plist");
|
||
if !plist.exists() {
|
||
return Ok(false);
|
||
}
|
||
|
||
let list_output = Command::new("launchctl")
|
||
.arg("list")
|
||
.output()
|
||
.context("Failed to query launchctl list")?;
|
||
let listed = String::from_utf8_lossy(&list_output.stdout);
|
||
if !listed.contains("com.zeroclaw.daemon") {
|
||
return Ok(false);
|
||
}
|
||
|
||
let _ = Command::new("launchctl")
|
||
.args(["stop", "com.zeroclaw.daemon"])
|
||
.output();
|
||
let start_output = Command::new("launchctl")
|
||
.args(["start", "com.zeroclaw.daemon"])
|
||
.output()
|
||
.context("Failed to start launchd daemon service")?;
|
||
if !start_output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&start_output.stderr);
|
||
anyhow::bail!("launchctl start failed: {}", stderr.trim());
|
||
}
|
||
|
||
return Ok(true);
|
||
}
|
||
|
||
if cfg!(target_os = "linux") {
|
||
let home = directories::UserDirs::new()
|
||
.map(|u| u.home_dir().to_path_buf())
|
||
.context("Could not find home directory")?;
|
||
let unit_path: PathBuf = home
|
||
.join(".config")
|
||
.join("systemd")
|
||
.join("user")
|
||
.join("zeroclaw.service");
|
||
if !unit_path.exists() {
|
||
return Ok(false);
|
||
}
|
||
|
||
let active_output = Command::new("systemctl")
|
||
.args(["--user", "is-active", "zeroclaw.service"])
|
||
.output()
|
||
.context("Failed to query systemd service state")?;
|
||
let state = String::from_utf8_lossy(&active_output.stdout);
|
||
if !state.trim().eq_ignore_ascii_case("active") {
|
||
return Ok(false);
|
||
}
|
||
|
||
let restart_output = Command::new("systemctl")
|
||
.args(["--user", "restart", "zeroclaw.service"])
|
||
.output()
|
||
.context("Failed to restart systemd daemon service")?;
|
||
if !restart_output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&restart_output.stderr);
|
||
anyhow::bail!("systemctl restart failed: {}", stderr.trim());
|
||
}
|
||
|
||
return Ok(true);
|
||
}
|
||
|
||
Ok(false)
|
||
}
|
||
|
||
pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
||
match command {
|
||
crate::ChannelCommands::Start => {
|
||
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
|
||
}
|
||
crate::ChannelCommands::Doctor => {
|
||
anyhow::bail!("Doctor must be handled in main.rs (requires async runtime)")
|
||
}
|
||
crate::ChannelCommands::List => {
|
||
println!("Channels:");
|
||
println!(" ✅ CLI (always available)");
|
||
for (name, configured) in [
|
||
("Telegram", config.channels_config.telegram.is_some()),
|
||
("Discord", config.channels_config.discord.is_some()),
|
||
("Slack", config.channels_config.slack.is_some()),
|
||
("Mattermost", config.channels_config.mattermost.is_some()),
|
||
("Webhook", config.channels_config.webhook.is_some()),
|
||
("iMessage", config.channels_config.imessage.is_some()),
|
||
(
|
||
"Matrix",
|
||
cfg!(feature = "channel-matrix") && config.channels_config.matrix.is_some(),
|
||
),
|
||
("Signal", config.channels_config.signal.is_some()),
|
||
("WhatsApp", config.channels_config.whatsapp.is_some()),
|
||
("Linq", config.channels_config.linq.is_some()),
|
||
("Email", config.channels_config.email.is_some()),
|
||
("IRC", config.channels_config.irc.is_some()),
|
||
("Lark", config.channels_config.lark.is_some()),
|
||
("DingTalk", config.channels_config.dingtalk.is_some()),
|
||
("QQ", config.channels_config.qq.is_some()),
|
||
] {
|
||
println!(" {} {name}", if configured { "✅" } else { "❌" });
|
||
}
|
||
if !cfg!(feature = "channel-matrix") {
|
||
println!(
|
||
" ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)."
|
||
);
|
||
}
|
||
println!("\nTo start channels: zeroclaw channel start");
|
||
println!("To check health: zeroclaw channel doctor");
|
||
println!("To configure: zeroclaw onboard");
|
||
Ok(())
|
||
}
|
||
crate::ChannelCommands::Add {
|
||
channel_type,
|
||
config: _,
|
||
} => {
|
||
anyhow::bail!(
|
||
"Channel type '{channel_type}' — use `zeroclaw onboard` to configure channels"
|
||
);
|
||
}
|
||
crate::ChannelCommands::Remove { name } => {
|
||
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
||
}
|
||
crate::ChannelCommands::BindTelegram { identity } => {
|
||
bind_telegram_identity(config, &identity).await
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Outcome of one time-limited channel `health_check`, as reported by the doctor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ChannelHealthState {
    /// The check finished in time and reported success.
    Healthy,
    /// The check finished in time but reported failure.
    Unhealthy,
    /// The check did not finish before the timeout elapsed.
    Timeout,
}
|
||
|
||
fn classify_health_result(
|
||
result: &std::result::Result<bool, tokio::time::error::Elapsed>,
|
||
) -> ChannelHealthState {
|
||
match result {
|
||
Ok(true) => ChannelHealthState::Healthy,
|
||
Ok(false) => ChannelHealthState::Unhealthy,
|
||
Err(_) => ChannelHealthState::Timeout,
|
||
}
|
||
}
|
||
|
||
/// Run health checks for configured channels.
|
||
pub async fn doctor_channels(config: Config) -> Result<()> {
|
||
let mut channels: Vec<(&'static str, Arc<dyn Channel>)> = Vec::new();
|
||
|
||
if let Some(ref tg) = config.channels_config.telegram {
|
||
channels.push((
|
||
"Telegram",
|
||
Arc::new(
|
||
TelegramChannel::new(
|
||
tg.bot_token.clone(),
|
||
tg.allowed_users.clone(),
|
||
tg.mention_only,
|
||
)
|
||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms),
|
||
),
|
||
));
|
||
}
|
||
|
||
if let Some(ref dc) = config.channels_config.discord {
|
||
channels.push((
|
||
"Discord",
|
||
Arc::new(DiscordChannel::new(
|
||
dc.bot_token.clone(),
|
||
dc.guild_id.clone(),
|
||
dc.allowed_users.clone(),
|
||
dc.listen_to_bots,
|
||
dc.mention_only,
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref sl) = config.channels_config.slack {
|
||
channels.push((
|
||
"Slack",
|
||
Arc::new(SlackChannel::new(
|
||
sl.bot_token.clone(),
|
||
sl.channel_id.clone(),
|
||
sl.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref im) = config.channels_config.imessage {
|
||
channels.push((
|
||
"iMessage",
|
||
Arc::new(IMessageChannel::new(im.allowed_contacts.clone())),
|
||
));
|
||
}
|
||
|
||
#[cfg(feature = "channel-matrix")]
|
||
if let Some(ref mx) = config.channels_config.matrix {
|
||
channels.push((
|
||
"Matrix",
|
||
Arc::new(MatrixChannel::new_with_session_hint(
|
||
mx.homeserver.clone(),
|
||
mx.access_token.clone(),
|
||
mx.room_id.clone(),
|
||
mx.allowed_users.clone(),
|
||
mx.user_id.clone(),
|
||
mx.device_id.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
#[cfg(not(feature = "channel-matrix"))]
|
||
if config.channels_config.matrix.is_some() {
|
||
tracing::warn!(
|
||
"Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix health check."
|
||
);
|
||
}
|
||
|
||
if let Some(ref sig) = config.channels_config.signal {
|
||
channels.push((
|
||
"Signal",
|
||
Arc::new(SignalChannel::new(
|
||
sig.http_url.clone(),
|
||
sig.account.clone(),
|
||
sig.group_id.clone(),
|
||
sig.allowed_from.clone(),
|
||
sig.ignore_attachments,
|
||
sig.ignore_stories,
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref wa) = config.channels_config.whatsapp {
|
||
// Runtime negotiation: detect backend type from config
|
||
match wa.backend_type() {
|
||
"cloud" => {
|
||
// Cloud API mode: requires phone_number_id, access_token, verify_token
|
||
if wa.is_cloud_config() {
|
||
channels.push((
|
||
"WhatsApp",
|
||
Arc::new(WhatsAppChannel::new(
|
||
wa.access_token.clone().unwrap_or_default(),
|
||
wa.phone_number_id.clone().unwrap_or_default(),
|
||
wa.verify_token.clone().unwrap_or_default(),
|
||
wa.allowed_numbers.clone(),
|
||
)),
|
||
));
|
||
} else {
|
||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||
}
|
||
}
|
||
"web" => {
|
||
// Web mode: requires session_path
|
||
#[cfg(feature = "whatsapp-web")]
|
||
if wa.is_web_config() {
|
||
channels.push((
|
||
"WhatsApp",
|
||
Arc::new(WhatsAppWebChannel::new(
|
||
wa.session_path.clone().unwrap_or_default(),
|
||
wa.pair_phone.clone(),
|
||
wa.pair_code.clone(),
|
||
wa.allowed_numbers.clone(),
|
||
)),
|
||
));
|
||
} else {
|
||
tracing::warn!("WhatsApp Web configured but session_path not set");
|
||
}
|
||
#[cfg(not(feature = "whatsapp-web"))]
|
||
{
|
||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||
}
|
||
}
|
||
_ => {
|
||
tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set");
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(ref lq) = config.channels_config.linq {
|
||
channels.push((
|
||
"Linq",
|
||
Arc::new(LinqChannel::new(
|
||
lq.api_token.clone(),
|
||
lq.from_phone.clone(),
|
||
lq.allowed_senders.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref email_cfg) = config.channels_config.email {
|
||
channels.push(("Email", Arc::new(EmailChannel::new(email_cfg.clone()))));
|
||
}
|
||
|
||
if let Some(ref irc) = config.channels_config.irc {
|
||
channels.push((
|
||
"IRC",
|
||
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||
server: irc.server.clone(),
|
||
port: irc.port,
|
||
nickname: irc.nickname.clone(),
|
||
username: irc.username.clone(),
|
||
channels: irc.channels.clone(),
|
||
allowed_users: irc.allowed_users.clone(),
|
||
server_password: irc.server_password.clone(),
|
||
nickserv_password: irc.nickserv_password.clone(),
|
||
sasl_password: irc.sasl_password.clone(),
|
||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||
})),
|
||
));
|
||
}
|
||
|
||
if let Some(ref lk) = config.channels_config.lark {
|
||
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
|
||
}
|
||
|
||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||
channels.push((
|
||
"DingTalk",
|
||
Arc::new(DingTalkChannel::new(
|
||
dt.client_id.clone(),
|
||
dt.client_secret.clone(),
|
||
dt.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if let Some(ref qq) = config.channels_config.qq {
|
||
channels.push((
|
||
"QQ",
|
||
Arc::new(QQChannel::new(
|
||
qq.app_id.clone(),
|
||
qq.app_secret.clone(),
|
||
qq.allowed_users.clone(),
|
||
)),
|
||
));
|
||
}
|
||
|
||
if channels.is_empty() {
|
||
println!("No real-time channels configured. Run `zeroclaw onboard` first.");
|
||
return Ok(());
|
||
}
|
||
|
||
println!("🩺 ZeroClaw Channel Doctor");
|
||
println!();
|
||
|
||
let mut healthy = 0_u32;
|
||
let mut unhealthy = 0_u32;
|
||
let mut timeout = 0_u32;
|
||
|
||
for (name, channel) in channels {
|
||
let result = tokio::time::timeout(Duration::from_secs(10), channel.health_check()).await;
|
||
let state = classify_health_result(&result);
|
||
|
||
match state {
|
||
ChannelHealthState::Healthy => {
|
||
healthy += 1;
|
||
println!(" ✅ {name:<9} healthy");
|
||
}
|
||
ChannelHealthState::Unhealthy => {
|
||
unhealthy += 1;
|
||
println!(" ❌ {name:<9} unhealthy (auth/config/network)");
|
||
}
|
||
ChannelHealthState::Timeout => {
|
||
timeout += 1;
|
||
println!(" ⏱️ {name:<9} timed out (>10s)");
|
||
}
|
||
}
|
||
}
|
||
|
||
if config.channels_config.webhook.is_some() {
|
||
println!(" ℹ️ Webhook check via `zeroclaw gateway` then GET /health");
|
||
}
|
||
|
||
println!();
|
||
println!("Summary: {healthy} healthy, {unhealthy} unhealthy, {timeout} timed out");
|
||
Ok(())
|
||
}
|
||
|
||
/// Start all configured channels and route messages to the agent
|
||
#[allow(clippy::too_many_lines)]
|
||
pub async fn start_channels(config: Config) -> Result<()> {
|
||
let provider_name = config
|
||
.default_provider
|
||
.clone()
|
||
.unwrap_or_else(|| "openrouter".into());
|
||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||
auth_profile_override: None,
|
||
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
|
||
secrets_encrypt: config.secrets.encrypt,
|
||
reasoning_enabled: config.runtime.reasoning_enabled,
|
||
};
|
||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options(
|
||
&provider_name,
|
||
config.api_key.as_deref(),
|
||
config.api_url.as_deref(),
|
||
&config.reliability,
|
||
&provider_runtime_options,
|
||
)?);
|
||
|
||
// Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup)
|
||
// so the first real message doesn't hit a cold-start timeout.
|
||
if let Err(e) = provider.warmup().await {
|
||
tracing::warn!("Provider warmup failed (non-fatal): {e}");
|
||
}
|
||
|
||
let observer: Arc<dyn Observer> =
|
||
Arc::from(observability::create_observer(&config.observability));
|
||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||
let security = Arc::new(SecurityPolicy::from_config(
|
||
&config.autonomy,
|
||
&config.workspace_dir,
|
||
));
|
||
let model = config
|
||
.default_model
|
||
.clone()
|
||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||
let temperature = config.default_temperature;
|
||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
|
||
&config.memory,
|
||
Some(&config.storage.provider.config),
|
||
&config.workspace_dir,
|
||
config.api_key.as_deref(),
|
||
)?);
|
||
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
||
(
|
||
config.composio.api_key.as_deref(),
|
||
Some(config.composio.entity_id.as_str()),
|
||
)
|
||
} else {
|
||
(None, None)
|
||
};
|
||
// Build system prompt from workspace identity files + skills
|
||
let workspace = config.workspace_dir.clone();
|
||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||
Arc::new(config.clone()),
|
||
&security,
|
||
runtime,
|
||
Arc::clone(&mem),
|
||
composio_key,
|
||
composio_entity_id,
|
||
&config.browser,
|
||
&config.http_request,
|
||
&workspace,
|
||
&config.agents,
|
||
config.api_key.as_deref(),
|
||
&config,
|
||
));
|
||
|
||
let skills = crate::skills::load_skills(&workspace);
|
||
|
||
// Collect tool descriptions for the prompt
|
||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||
(
|
||
"shell",
|
||
"Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
|
||
),
|
||
(
|
||
"file_read",
|
||
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
|
||
),
|
||
(
|
||
"file_write",
|
||
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
|
||
),
|
||
(
|
||
"memory_store",
|
||
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
||
),
|
||
(
|
||
"memory_recall",
|
||
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
||
),
|
||
(
|
||
"memory_forget",
|
||
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
|
||
),
|
||
];
|
||
|
||
if config.browser.enabled {
|
||
tool_descs.push((
|
||
"browser_open",
|
||
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
|
||
));
|
||
}
|
||
if config.composio.enabled {
|
||
tool_descs.push((
|
||
"composio",
|
||
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover actions, 'list_accounts' to retrieve connected account IDs, 'execute' to run (optionally with connected_account_id), and 'connect' for OAuth.",
|
||
));
|
||
}
|
||
tool_descs.push((
|
||
"schedule",
|
||
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
||
));
|
||
tool_descs.push((
|
||
"pushover",
|
||
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
|
||
));
|
||
if !config.agents.is_empty() {
|
||
tool_descs.push((
|
||
"delegate",
|
||
"Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.",
|
||
));
|
||
}
|
||
|
||
let bootstrap_max_chars = if config.agent.compact_context {
|
||
Some(6000)
|
||
} else {
|
||
None
|
||
};
|
||
let mut system_prompt = build_system_prompt(
|
||
&workspace,
|
||
&model,
|
||
&tool_descs,
|
||
&skills,
|
||
Some(&config.identity),
|
||
bootstrap_max_chars,
|
||
);
|
||
system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref()));
|
||
|
||
if !skills.is_empty() {
|
||
println!(
|
||
" 🧩 Skills: {}",
|
||
skills
|
||
.iter()
|
||
.map(|s| s.name.as_str())
|
||
.collect::<Vec<_>>()
|
||
.join(", ")
|
||
);
|
||
}
|
||
|
||
// Collect active channels
|
||
let mut channels: Vec<Arc<dyn Channel>> = Vec::new();
|
||
|
||
if let Some(ref tg) = config.channels_config.telegram {
|
||
channels.push(Arc::new(
|
||
TelegramChannel::new(
|
||
tg.bot_token.clone(),
|
||
tg.allowed_users.clone(),
|
||
tg.mention_only,
|
||
)
|
||
.with_streaming(tg.stream_mode, tg.draft_update_interval_ms),
|
||
));
|
||
}
|
||
|
||
if let Some(ref dc) = config.channels_config.discord {
|
||
channels.push(Arc::new(DiscordChannel::new(
|
||
dc.bot_token.clone(),
|
||
dc.guild_id.clone(),
|
||
dc.allowed_users.clone(),
|
||
dc.listen_to_bots,
|
||
dc.mention_only,
|
||
)));
|
||
}
|
||
|
||
if let Some(ref sl) = config.channels_config.slack {
|
||
channels.push(Arc::new(SlackChannel::new(
|
||
sl.bot_token.clone(),
|
||
sl.channel_id.clone(),
|
||
sl.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref mm) = config.channels_config.mattermost {
|
||
channels.push(Arc::new(MattermostChannel::new(
|
||
mm.url.clone(),
|
||
mm.bot_token.clone(),
|
||
mm.channel_id.clone(),
|
||
mm.allowed_users.clone(),
|
||
mm.thread_replies.unwrap_or(true),
|
||
mm.mention_only.unwrap_or(false),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref im) = config.channels_config.imessage {
|
||
channels.push(Arc::new(IMessageChannel::new(im.allowed_contacts.clone())));
|
||
}
|
||
|
||
#[cfg(feature = "channel-matrix")]
|
||
if let Some(ref mx) = config.channels_config.matrix {
|
||
channels.push(Arc::new(MatrixChannel::new_with_session_hint(
|
||
mx.homeserver.clone(),
|
||
mx.access_token.clone(),
|
||
mx.room_id.clone(),
|
||
mx.allowed_users.clone(),
|
||
mx.user_id.clone(),
|
||
mx.device_id.clone(),
|
||
)));
|
||
}
|
||
|
||
#[cfg(not(feature = "channel-matrix"))]
|
||
if config.channels_config.matrix.is_some() {
|
||
tracing::warn!(
|
||
"Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix runtime startup."
|
||
);
|
||
}
|
||
|
||
if let Some(ref sig) = config.channels_config.signal {
|
||
channels.push(Arc::new(SignalChannel::new(
|
||
sig.http_url.clone(),
|
||
sig.account.clone(),
|
||
sig.group_id.clone(),
|
||
sig.allowed_from.clone(),
|
||
sig.ignore_attachments,
|
||
sig.ignore_stories,
|
||
)));
|
||
}
|
||
|
||
if let Some(ref wa) = config.channels_config.whatsapp {
|
||
// Runtime negotiation: detect backend type from config
|
||
match wa.backend_type() {
|
||
"cloud" => {
|
||
// Cloud API mode: requires phone_number_id, access_token, verify_token
|
||
if wa.is_cloud_config() {
|
||
channels.push(Arc::new(WhatsAppChannel::new(
|
||
wa.access_token.clone().unwrap_or_default(),
|
||
wa.phone_number_id.clone().unwrap_or_default(),
|
||
wa.verify_token.clone().unwrap_or_default(),
|
||
wa.allowed_numbers.clone(),
|
||
)));
|
||
} else {
|
||
tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)");
|
||
}
|
||
}
|
||
"web" => {
|
||
// Web mode: requires session_path
|
||
#[cfg(feature = "whatsapp-web")]
|
||
if wa.is_web_config() {
|
||
channels.push(Arc::new(WhatsAppWebChannel::new(
|
||
wa.session_path.clone().unwrap_or_default(),
|
||
wa.pair_phone.clone(),
|
||
wa.pair_code.clone(),
|
||
wa.allowed_numbers.clone(),
|
||
)));
|
||
} else {
|
||
tracing::warn!("WhatsApp Web configured but session_path not set");
|
||
}
|
||
#[cfg(not(feature = "whatsapp-web"))]
|
||
{
|
||
tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. Enable with: cargo build --features whatsapp-web");
|
||
}
|
||
}
|
||
_ => {
|
||
tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set");
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(ref lq) = config.channels_config.linq {
|
||
channels.push(Arc::new(LinqChannel::new(
|
||
lq.api_token.clone(),
|
||
lq.from_phone.clone(),
|
||
lq.allowed_senders.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref email_cfg) = config.channels_config.email {
|
||
channels.push(Arc::new(EmailChannel::new(email_cfg.clone())));
|
||
}
|
||
|
||
if let Some(ref irc) = config.channels_config.irc {
|
||
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||
server: irc.server.clone(),
|
||
port: irc.port,
|
||
nickname: irc.nickname.clone(),
|
||
username: irc.username.clone(),
|
||
channels: irc.channels.clone(),
|
||
allowed_users: irc.allowed_users.clone(),
|
||
server_password: irc.server_password.clone(),
|
||
nickserv_password: irc.nickserv_password.clone(),
|
||
sasl_password: irc.sasl_password.clone(),
|
||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||
})));
|
||
}
|
||
|
||
if let Some(ref lk) = config.channels_config.lark {
|
||
channels.push(Arc::new(LarkChannel::from_config(lk)));
|
||
}
|
||
|
||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||
channels.push(Arc::new(DingTalkChannel::new(
|
||
dt.client_id.clone(),
|
||
dt.client_secret.clone(),
|
||
dt.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if let Some(ref qq) = config.channels_config.qq {
|
||
channels.push(Arc::new(QQChannel::new(
|
||
qq.app_id.clone(),
|
||
qq.app_secret.clone(),
|
||
qq.allowed_users.clone(),
|
||
)));
|
||
}
|
||
|
||
if channels.is_empty() {
|
||
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
|
||
return Ok(());
|
||
}
|
||
|
||
println!("🦀 ZeroClaw Channel Server");
|
||
println!(" 🤖 Model: {model}");
|
||
let effective_backend = memory::effective_memory_backend_name(
|
||
&config.memory.backend,
|
||
Some(&config.storage.provider.config),
|
||
);
|
||
println!(
|
||
" 🧠 Memory: {} (auto-save: {})",
|
||
effective_backend,
|
||
if config.memory.auto_save { "on" } else { "off" }
|
||
);
|
||
println!(
|
||
" 📡 Channels: {}",
|
||
channels
|
||
.iter()
|
||
.map(|c| c.name())
|
||
.collect::<Vec<_>>()
|
||
.join(", ")
|
||
);
|
||
println!();
|
||
println!(" Listening for messages... (Ctrl+C to stop)");
|
||
println!();
|
||
|
||
crate::health::mark_component_ok("channels");
|
||
|
||
let initial_backoff_secs = config
|
||
.reliability
|
||
.channel_initial_backoff_secs
|
||
.max(DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS);
|
||
let max_backoff_secs = config
|
||
.reliability
|
||
.channel_max_backoff_secs
|
||
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
|
||
|
||
// Single message bus — all channels send messages here
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||
|
||
// Spawn a listener for each channel
|
||
let mut handles = Vec::new();
|
||
for ch in &channels {
|
||
handles.push(spawn_supervised_listener(
|
||
ch.clone(),
|
||
tx.clone(),
|
||
initial_backoff_secs,
|
||
max_backoff_secs,
|
||
));
|
||
}
|
||
drop(tx); // Drop our copy so rx closes when all channels stop
|
||
|
||
let channels_by_name = Arc::new(
|
||
channels
|
||
.iter()
|
||
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
|
||
.collect::<HashMap<_, _>>(),
|
||
);
|
||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||
|
||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
||
let message_timeout_secs =
|
||
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
||
let interrupt_on_new_message = config
|
||
.channels_config
|
||
.telegram
|
||
.as_ref()
|
||
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name,
|
||
provider: Arc::clone(&provider),
|
||
default_provider: Arc::new(provider_name),
|
||
memory: Arc::clone(&mem),
|
||
tools_registry: Arc::clone(&tools_registry),
|
||
observer,
|
||
system_prompt: Arc::new(system_prompt),
|
||
model: Arc::new(model.clone()),
|
||
temperature,
|
||
auto_save_memory: config.memory.auto_save,
|
||
max_tool_iterations: config.agent.max_tool_iterations,
|
||
min_relevance_score: config.memory.min_relevance_score,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: config.api_key.clone(),
|
||
api_url: config.api_url.clone(),
|
||
reliability: Arc::new(config.reliability.clone()),
|
||
provider_runtime_options,
|
||
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
||
message_timeout_secs,
|
||
interrupt_on_new_message,
|
||
multimodal: config.multimodal.clone(),
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||
|
||
// Wait for all channel tasks
|
||
for h in handles {
|
||
let _ = h.await;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::*;
|
||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||
use crate::observability::NoopObserver;
|
||
use crate::providers::{ChatMessage, Provider};
|
||
use crate::tools::{Tool, ToolResult};
|
||
use std::collections::HashMap;
|
||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||
use std::sync::Arc;
|
||
use tempfile::TempDir;
|
||
|
||
fn make_workspace() -> TempDir {
|
||
let tmp = TempDir::new().unwrap();
|
||
// Create minimal workspace files
|
||
std::fs::write(tmp.path().join("SOUL.md"), "# Soul\nBe helpful.").unwrap();
|
||
std::fs::write(tmp.path().join("IDENTITY.md"), "# Identity\nName: ZeroClaw").unwrap();
|
||
std::fs::write(tmp.path().join("USER.md"), "# User\nName: Test User").unwrap();
|
||
std::fs::write(
|
||
tmp.path().join("AGENTS.md"),
|
||
"# Agents\nFollow instructions.",
|
||
)
|
||
.unwrap();
|
||
std::fs::write(tmp.path().join("TOOLS.md"), "# Tools\nUse shell carefully.").unwrap();
|
||
std::fs::write(
|
||
tmp.path().join("HEARTBEAT.md"),
|
||
"# Heartbeat\nCheck status.",
|
||
)
|
||
.unwrap();
|
||
std::fs::write(tmp.path().join("MEMORY.md"), "# Memory\nUser likes Rust.").unwrap();
|
||
tmp
|
||
}
|
||
|
||
#[test]
fn effective_channel_message_timeout_secs_clamps_to_minimum() {
    // Small inputs (0 and 15) are raised to the configured floor;
    // a large input (300) is returned unchanged.
    let floor = MIN_CHANNEL_MESSAGE_TIMEOUT_SECS;
    assert_eq!(effective_channel_message_timeout_secs(0), floor);
    assert_eq!(effective_channel_message_timeout_secs(15), floor);

    let generous = effective_channel_message_timeout_secs(300);
    assert_eq!(generous, 300);
}
|
||
|
||
#[test]
fn context_window_overflow_error_detector_matches_known_messages() {
    // A context-window overflow message is recognised; an unrelated gateway
    // failure from the same provider is not.
    let overflow = anyhow::anyhow!(
        "OpenAI Codex stream error: Your input exceeds the context window of this model."
    );
    let unrelated =
        anyhow::anyhow!("OpenAI Codex API error (502 Bad Gateway): error code: 502");

    assert!(is_context_window_overflow_error(&overflow));
    assert!(!is_context_window_overflow_error(&unrelated));
}
|
||
|
||
#[test]
fn memory_context_skip_rules_exclude_history_blobs() {
    // Each case is (memory key, stored content, expected "skip" verdict).
    let cases = [
        ("telegram_123_history", r#"[{"role":"user"}]"#, true),
        ("assistant_resp_legacy", "fabricated memory", true),
        ("telegram_123_45", "hi", false),
    ];

    for (key, content, should_skip) in cases {
        assert_eq!(should_skip_memory_context_entry(key, content), should_skip);
    }
}
|
||
|
||
#[test]
fn compact_sender_history_keeps_recent_truncated_messages() {
    let sender = "telegram_u1".to_string();

    // Twenty alternating user/assistant turns, each with a ~700-character body.
    let turns = (0..20)
        .map(|idx| {
            let body = format!("msg-{idx}-{}", "x".repeat(700));
            match idx % 2 {
                0 => ChatMessage::user(body),
                _ => ChatMessage::assistant(body),
            }
        })
        .collect::<Vec<_>>();

    let mut histories = HashMap::new();
    histories.insert(sender.clone(), turns);

    let ctx = ChannelRuntimeContext {
        channels_by_name: Arc::new(HashMap::new()),
        provider: Arc::new(DummyProvider),
        default_provider: Arc::new("test-provider".to_string()),
        memory: Arc::new(NoopMemory),
        tools_registry: Arc::new(vec![]),
        observer: Arc::new(NoopObserver),
        system_prompt: Arc::new("system".to_string()),
        model: Arc::new("test-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
        max_tool_iterations: 5,
        min_relevance_score: 0.0,
        conversation_histories: Arc::new(Mutex::new(histories)),
        provider_cache: Arc::new(Mutex::new(HashMap::new())),
        route_overrides: Arc::new(Mutex::new(HashMap::new())),
        api_key: None,
        api_url: None,
        reliability: Arc::new(crate::config::ReliabilityConfig::default()),
        interrupt_on_new_message: false,
        multimodal: crate::config::MultimodalConfig::default(),
        provider_runtime_options: providers::ProviderRuntimeOptions::default(),
        workspace_dir: Arc::new(std::env::temp_dir()),
        message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
    };

    assert!(compact_sender_history(&ctx, &sender));

    let guard = ctx
        .conversation_histories
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    let kept = guard
        .get(&sender)
        .expect("sender history should remain");
    assert_eq!(kept.len(), CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES);

    // Every surviving turn fits the content cap, allowing a trailing "..." marker.
    let limit = CHANNEL_HISTORY_COMPACT_CONTENT_CHARS;
    let all_fit = kept.iter().all(|turn| {
        let chars = turn.content.chars().count();
        chars <= limit || (chars <= limit + 3 && turn.content.ends_with("..."))
    });
    assert!(all_fit);
}
|
||
|
||
struct DummyProvider;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for DummyProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("ok".to_string())
|
||
}
|
||
}
|
||
|
||
/// Test channel that records every outgoing message and counts
/// typing-indicator calls, so tests can assert on delivery and typing behavior.
#[derive(Default)]
struct RecordingChannel {
    /// Each sent message, stored as `recipient:content`.
    sent_messages: tokio::sync::Mutex<Vec<String>>,
    /// Incremented once per `start_typing` call.
    start_typing_calls: AtomicUsize,
    /// Incremented once per `stop_typing` call.
    stop_typing_calls: AtomicUsize,
}
|
||
|
||
#[derive(Default)]
|
||
struct TelegramRecordingChannel {
|
||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for TelegramRecordingChannel {
|
||
fn name(&self) -> &str {
|
||
"telegram"
|
||
}
|
||
|
||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||
self.sent_messages
|
||
.lock()
|
||
.await
|
||
.push(format!("{}:{}", message.recipient, message.content));
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for RecordingChannel {
|
||
fn name(&self) -> &str {
|
||
"test-channel"
|
||
}
|
||
|
||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||
self.sent_messages
|
||
.lock()
|
||
.await
|
||
.push(format!("{}:{}", message.recipient, message.content));
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
self.start_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||
Ok(())
|
||
}
|
||
|
||
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||
self.stop_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
struct SlowProvider {
|
||
delay: Duration,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for SlowProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
tokio::time::sleep(self.delay).await;
|
||
Ok(format!("echo: {message}"))
|
||
}
|
||
}
|
||
|
||
/// Provider that requests `mock_price` via `<tool_call>` tags until tool results appear.
struct ToolCallingProvider;

/// Builds a `<tool_call>`-wrapped request for `mock_price` with `{"symbol":"BTC"}`.
fn tool_call_payload() -> String {
    concat!(
        "<tool_call>\n",
        r#"{"name":"mock_price","arguments":{"symbol":"BTC"}}"#,
        "\n</tool_call>"
    )
    .to_owned()
}

/// Same request as [`tool_call_payload`], but wrapped in the `<toolcall>` alias tag.
fn tool_call_payload_with_alias_tag() -> String {
    concat!(
        "<toolcall>\n",
        r#"{"name":"mock_price","arguments":{"symbol":"BTC"}}"#,
        "\n</toolcall>"
    )
    .to_owned()
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ToolCallingProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let has_tool_results = messages
|
||
.iter()
|
||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||
if has_tool_results {
|
||
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
|
||
} else {
|
||
Ok(tool_call_payload())
|
||
}
|
||
}
|
||
}
|
||
|
||
struct ToolCallingAliasProvider;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ToolCallingAliasProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload_with_alias_tag())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let has_tool_results = messages
|
||
.iter()
|
||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||
if has_tool_results {
|
||
Ok("BTC alias-tag flow resolved to final text output.".to_string())
|
||
} else {
|
||
Ok(tool_call_payload_with_alias_tag())
|
||
}
|
||
}
|
||
}
|
||
|
||
struct IterativeToolProvider {
|
||
required_tool_iterations: usize,
|
||
}
|
||
|
||
impl IterativeToolProvider {
|
||
fn completed_tool_iterations(messages: &[ChatMessage]) -> usize {
|
||
messages
|
||
.iter()
|
||
.filter(|msg| msg.role == "user" && msg.content.contains("[Tool results]"))
|
||
.count()
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for IterativeToolProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok(tool_call_payload())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let completed_iterations = Self::completed_tool_iterations(messages);
|
||
if completed_iterations >= self.required_tool_iterations {
|
||
Ok(format!(
|
||
"Completed after {completed_iterations} tool iterations."
|
||
))
|
||
} else {
|
||
Ok(tool_call_payload())
|
||
}
|
||
}
|
||
}
|
||
|
||
#[derive(Default)]
|
||
struct HistoryCaptureProvider {
|
||
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for HistoryCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let snapshot = messages
|
||
.iter()
|
||
.map(|m| (m.role.clone(), m.content.clone()))
|
||
.collect::<Vec<_>>();
|
||
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||
calls.push(snapshot);
|
||
Ok(format!("response-{}", calls.len()))
|
||
}
|
||
}
|
||
|
||
struct DelayedHistoryCaptureProvider {
|
||
delay: Duration,
|
||
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for DelayedHistoryCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
messages: &[ChatMessage],
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
let snapshot = messages
|
||
.iter()
|
||
.map(|m| (m.role.clone(), m.content.clone()))
|
||
.collect::<Vec<_>>();
|
||
let call_index = {
|
||
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||
calls.push(snapshot);
|
||
calls.len()
|
||
};
|
||
tokio::time::sleep(self.delay).await;
|
||
Ok(format!("response-{call_index}"))
|
||
}
|
||
}
|
||
|
||
/// Tool stub registered as `mock_price`: returns a fixed BTC quote and rejects other symbols.
struct MockPriceTool;
|
||
|
||
#[derive(Default)]
|
||
struct ModelCaptureProvider {
|
||
call_count: AtomicUsize,
|
||
models: std::sync::Mutex<Vec<String>>,
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Provider for ModelCaptureProvider {
|
||
async fn chat_with_system(
|
||
&self,
|
||
_system_prompt: Option<&str>,
|
||
_message: &str,
|
||
_model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
Ok("fallback".to_string())
|
||
}
|
||
|
||
async fn chat_with_history(
|
||
&self,
|
||
_messages: &[ChatMessage],
|
||
model: &str,
|
||
_temperature: f64,
|
||
) -> anyhow::Result<String> {
|
||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||
self.models
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.push(model.to_string());
|
||
Ok("ok".to_string())
|
||
}
|
||
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Tool for MockPriceTool {
|
||
fn name(&self) -> &str {
|
||
"mock_price"
|
||
}
|
||
|
||
fn description(&self) -> &str {
|
||
"Return a mocked BTC price"
|
||
}
|
||
|
||
fn parameters_schema(&self) -> serde_json::Value {
|
||
serde_json::json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"symbol": { "type": "string" }
|
||
},
|
||
"required": ["symbol"]
|
||
})
|
||
}
|
||
|
||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||
let symbol = args.get("symbol").and_then(serde_json::Value::as_str);
|
||
if symbol != Some("BTC") {
|
||
return Ok(ToolResult {
|
||
success: false,
|
||
output: String::new(),
|
||
error: Some("unexpected symbol".to_string()),
|
||
});
|
||
}
|
||
|
||
Ok(ToolResult {
|
||
success: true,
|
||
output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(),
|
||
error: None,
|
||
})
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(ToolCallingProvider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-42".to_string(),
|
||
content: "What is the BTC price now?".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-42:"));
|
||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||
assert!(!sent_messages[0].contains("mock_price"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_executes_tool_calls_with_alias_tags() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(ToolCallingAliasProvider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-2".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "chat-84".to_string(),
|
||
content: "What is the BTC price now?".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-84:"));
|
||
assert!(sent_messages[0].contains("alias-tag flow resolved"));
|
||
assert!(!sent_messages[0].contains("<toolcall>"));
|
||
assert!(!sent_messages[0].contains("mock_price"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_handles_models_command_without_llm_call() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||
let fallback_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let fallback_provider: Arc<dyn Provider> = fallback_provider_impl.clone();
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||
provider_cache_seed.insert("openrouter".to_string(), fallback_provider);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::clone(&default_provider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("default-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "msg-cmd-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "/models openrouter".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent.len(), 1);
|
||
assert!(sent[0].contains("Provider switched to `openrouter`"));
|
||
|
||
let route_key = "telegram_alice";
|
||
let route = runtime_ctx
|
||
.route_overrides
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.get(route_key)
|
||
.cloned()
|
||
.expect("route should be stored for sender");
|
||
assert_eq!(route.provider, "openrouter");
|
||
assert_eq!(route.model, "default-model");
|
||
|
||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
assert_eq!(fallback_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_uses_route_override_provider_and_model() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let default_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let default_provider: Arc<dyn Provider> = default_provider_impl.clone();
|
||
let routed_provider_impl = Arc::new(ModelCaptureProvider::default());
|
||
let routed_provider: Arc<dyn Provider> = routed_provider_impl.clone();
|
||
|
||
let mut provider_cache_seed: HashMap<String, Arc<dyn Provider>> = HashMap::new();
|
||
provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&default_provider));
|
||
provider_cache_seed.insert("openrouter".to_string(), routed_provider);
|
||
|
||
let route_key = "telegram_alice".to_string();
|
||
let mut route_overrides = HashMap::new();
|
||
route_overrides.insert(
|
||
route_key,
|
||
ChannelRouteSelection {
|
||
provider: "openrouter".to_string(),
|
||
model: "route-model".to_string(),
|
||
},
|
||
);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::clone(&default_provider),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("default-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(provider_cache_seed)),
|
||
route_overrides: Arc::new(Mutex::new(route_overrides)),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-routed-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "hello routed provider".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
assert_eq!(default_provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||
assert_eq!(routed_provider_impl.call_count.load(Ordering::SeqCst), 1);
|
||
assert_eq!(
|
||
routed_provider_impl
|
||
.models
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.as_slice(),
|
||
&["route-model".to_string()]
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_respects_configured_max_tool_iterations_above_default() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(IterativeToolProvider {
|
||
required_tool_iterations: 11,
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 12,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-iter-success".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-iter-success".to_string(),
|
||
content: "Loop until done".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-iter-success:"));
|
||
assert!(sent_messages[0].contains("Completed after 11 tool iterations."));
|
||
assert!(!sent_messages[0].contains("⚠️ Error:"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_reports_configured_max_tool_iterations_limit() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(IterativeToolProvider {
|
||
required_tool_iterations: 20,
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 3,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-iter-fail".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "chat-iter-fail".to_string(),
|
||
content: "Loop forever".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-iter-fail:"));
|
||
assert!(sent_messages[0].contains("⚠️ Error: Agent exceeded maximum tool iterations (3)"));
|
||
}
|
||
|
||
struct NoopMemory;
|
||
|
||
#[async_trait::async_trait]
|
||
impl Memory for NoopMemory {
|
||
fn name(&self) -> &str {
|
||
"noop"
|
||
}
|
||
|
||
async fn store(
|
||
&self,
|
||
_key: &str,
|
||
_content: &str,
|
||
_category: crate::memory::MemoryCategory,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn recall(
|
||
&self,
|
||
_query: &str,
|
||
_limit: usize,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||
Ok(Vec::new())
|
||
}
|
||
|
||
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
|
||
Ok(None)
|
||
}
|
||
|
||
async fn list(
|
||
&self,
|
||
_category: Option<&crate::memory::MemoryCategory>,
|
||
_session_id: Option<&str>,
|
||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||
Ok(Vec::new())
|
||
}
|
||
|
||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||
Ok(false)
|
||
}
|
||
|
||
async fn count(&self) -> anyhow::Result<usize> {
|
||
Ok(0)
|
||
}
|
||
|
||
async fn health_check(&self) -> bool {
|
||
true
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_processes_messages_in_parallel() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(250),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||
tx.send(traits::ChannelMessage {
|
||
id: "1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "alice".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tx.send(traits::ChannelMessage {
|
||
id: "2".to_string(),
|
||
sender: "bob".to_string(),
|
||
reply_target: "bob".to_string(),
|
||
content: "world".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
drop(tx);
|
||
|
||
let started = Instant::now();
|
||
run_message_dispatch_loop(rx, runtime_ctx, 2).await;
|
||
let elapsed = started.elapsed();
|
||
|
||
assert!(
|
||
elapsed < Duration::from_millis(430),
|
||
"expected parallel dispatch (<430ms), got {:?}",
|
||
elapsed
|
||
);
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 2);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(DelayedHistoryCaptureProvider {
|
||
delay: Duration::from_millis(250),
|
||
calls: std::sync::Mutex::new(Vec::new()),
|
||
});
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: true,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||
let send_task = tokio::spawn(async move {
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "forwarded content".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tokio::time::sleep(Duration::from_millis(40)).await;
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-2".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "summarize this".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||
send_task.await.unwrap();
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 1);
|
||
assert!(sent_messages[0].starts_with("chat-1:"));
|
||
assert!(sent_messages[0].contains("response-2"));
|
||
drop(sent_messages);
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 2);
|
||
let second_call = &calls[1];
|
||
assert!(second_call
|
||
.iter()
|
||
.any(|(role, content)| { role == "user" && content.contains("forwarded content") }));
|
||
assert!(second_call
|
||
.iter()
|
||
.any(|(role, content)| { role == "user" && content.contains("summarize this") }));
|
||
assert!(
|
||
!second_call.iter().any(|(role, _)| role == "assistant"),
|
||
"cancelled turn should not persist an assistant response"
|
||
);
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(180),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: true,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||
let send_task = tokio::spawn(async move {
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-a".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "first chat".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||
tx.send(traits::ChannelMessage {
|
||
id: "msg-b".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-2".to_string(),
|
||
content: "second chat".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
})
|
||
.await
|
||
.unwrap();
|
||
});
|
||
|
||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||
send_task.await.unwrap();
|
||
|
||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||
assert_eq!(sent_messages.len(), 2);
|
||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:")));
|
||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:")));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_cancels_scoped_typing_task() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: Arc::new(SlowProvider {
|
||
delay: Duration::from_millis(20),
|
||
}),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 10,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "typing-msg".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-typing".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let starts = channel_impl.start_typing_calls.load(Ordering::SeqCst);
|
||
let stops = channel_impl.stop_typing_calls.load(Ordering::SeqCst);
|
||
assert_eq!(starts, 1, "start_typing should be called once");
|
||
assert_eq!(stops, 1, "stop_typing should be called once");
|
||
}
|
||
|
||
#[test]
fn prompt_contains_all_sections() {
    let workspace = make_workspace();
    let tool_list = vec![("shell", "Run commands"), ("file_read", "Read files")];
    let prompt = build_system_prompt(workspace.path(), "test-model", &tool_list, &[], None, None);

    // Each section header, paired with the message reported if it is absent.
    let required_sections = [
        ("## Tools", "missing Tools section"),
        ("## Safety", "missing Safety section"),
        ("## Workspace", "missing Workspace section"),
        ("## Project Context", "missing Project Context"),
        ("## Current Date & Time", "missing Date/Time"),
        ("## Runtime", "missing Runtime section"),
    ];
    for (header, failure) in required_sections {
        assert!(prompt.contains(header), "{}", failure);
    }
}
|
||
|
||
#[test]
fn prompt_injects_tools() {
    let workspace = make_workspace();
    let tool_list = vec![("shell", "Run commands"), ("memory_recall", "Search memory")];
    let prompt = build_system_prompt(workspace.path(), "gpt-4o", &tool_list, &[], None, None);

    // Tool names are rendered in bold alongside their descriptions.
    for fragment in ["**shell**", "Run commands", "**memory_recall**"] {
        assert!(prompt.contains(fragment), "prompt is missing {fragment:?}");
    }
}
|
||
|
||
#[test]
fn prompt_includes_single_tool_protocol_block_after_append() {
    let workspace = make_workspace();
    let tool_list = vec![("shell", "Run commands")];
    let base_prompt = build_system_prompt(workspace.path(), "gpt-4o", &tool_list, &[], None, None);

    assert!(
        !base_prompt.contains("## Tool Use Protocol"),
        "build_system_prompt should not emit protocol block directly"
    );

    // The caller appends the tool instructions; the protocol header must then appear once.
    let final_prompt = base_prompt + &build_tool_instructions(&[]);
    assert_eq!(
        final_prompt.matches("## Tool Use Protocol").count(),
        1,
        "protocol block should appear exactly once in the final prompt"
    );
}
|
||
|
||
#[test]
fn prompt_injects_safety() {
    let workspace = make_workspace();
    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    let safety_rules = [
        "Do not exfiltrate private data",
        "Do not run destructive commands",
        "Prefer `trash` over `rm`",
    ];
    for rule in safety_rules {
        assert!(prompt.contains(rule), "prompt is missing safety rule {rule:?}");
    }
}
|
||
|
||
#[test]
fn prompt_injects_workspace_files() {
    let workspace = make_workspace();
    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    // (needle, should it be present?, failure message), checked in order.
    let expectations = [
        ("### SOUL.md", true, "missing SOUL.md header"),
        ("Be helpful", true, "missing SOUL content"),
        ("### IDENTITY.md", true, "missing IDENTITY.md"),
        ("Name: ZeroClaw", true, "missing IDENTITY content"),
        ("### USER.md", true, "missing USER.md"),
        ("### AGENTS.md", true, "missing AGENTS.md"),
        ("### TOOLS.md", true, "missing TOOLS.md"),
        // HEARTBEAT.md is deliberately kept out of channel prompts. It only matters to the
        // heartbeat worker, and including it makes LLMs reply with stray "HEARTBEAT_OK"
        // acknowledgments in channel conversations.
        (
            "### HEARTBEAT.md",
            false,
            "HEARTBEAT.md should not be in channel prompt",
        ),
        ("### MEMORY.md", true, "missing MEMORY.md"),
        ("User likes Rust", true, "missing MEMORY content"),
    ];
    for (needle, should_be_present, failure) in expectations {
        assert!(prompt.contains(needle) == should_be_present, "{}", failure);
    }
}
|
||
|
||
#[test]
fn prompt_missing_file_markers() {
    // A completely empty workspace: none of the bootstrap files exist.
    let empty_dir = TempDir::new().unwrap();
    let prompt = build_system_prompt(empty_dir.path(), "model", &[], &[], None, None);

    for missing in ["SOUL.md", "AGENTS.md", "IDENTITY.md"] {
        let marker = format!("[File not found: {missing}]");
        assert!(prompt.contains(&marker), "expected marker {marker:?}");
    }
}
|
||
|
||
#[test]
fn prompt_bootstrap_only_if_exists() {
    let workspace = make_workspace();
    let bootstrap_header = "### BOOTSTRAP.md";

    // The default workspace has no BOOTSTRAP.md, so no section is emitted.
    let without_file = build_system_prompt(workspace.path(), "model", &[], &[], None, None);
    assert!(
        !without_file.contains(bootstrap_header),
        "BOOTSTRAP.md should not appear when missing"
    );

    // Once the file exists, its header and content are included.
    std::fs::write(
        workspace.path().join("BOOTSTRAP.md"),
        "# Bootstrap\nFirst run.",
    )
    .unwrap();
    let with_file = build_system_prompt(workspace.path(), "model", &[], &[], None, None);
    assert!(
        with_file.contains(bootstrap_header),
        "BOOTSTRAP.md should appear when present"
    );
    assert!(with_file.contains("First run"));
}
|
||
|
||
#[test]
fn prompt_no_daily_memory_injection() {
    let workspace = make_workspace();

    // Write today's daily note at memory/<YYYY-MM-DD>.md.
    let daily_dir = workspace.path().join("memory");
    std::fs::create_dir_all(&daily_dir).unwrap();
    let today_file = daily_dir.join(format!(
        "{}.md",
        chrono::Local::now().format("%Y-%m-%d")
    ));
    std::fs::write(today_file, "# Daily\nSome note.").unwrap();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    // Daily notes are fetched on demand through tools, never baked into the system prompt.
    assert!(
        !prompt.contains("Daily Notes"),
        "daily notes should not be auto-injected"
    );
    assert!(
        !prompt.contains("Some note"),
        "daily content should not be in prompt"
    );
}
|
||
|
||
#[test]
fn prompt_runtime_metadata() {
    let workspace = make_workspace();
    let prompt = build_system_prompt(workspace.path(), "claude-sonnet-4", &[], &[], None, None);
    let os_line = format!("OS: {}", std::env::consts::OS);

    assert!(prompt.contains("Model: claude-sonnet-4"));
    assert!(prompt.contains(&os_line));
    assert!(prompt.contains("Host:"));
}
|
||
|
||
#[test]
fn prompt_skills_include_instructions_and_tools() {
    let workspace = make_workspace();

    let lint_tool = crate::skills::SkillTool {
        name: "lint".into(),
        description: "Run static checks".into(),
        kind: "shell".into(),
        command: "cargo clippy".into(),
        args: HashMap::new(),
    };
    let review_skill = crate::skills::Skill {
        name: "code-review".into(),
        description: "Review code for bugs".into(),
        version: "1.0.0".into(),
        author: None,
        tags: vec![],
        tools: vec![lint_tool],
        prompts: vec!["Always run cargo test before final response.".into()],
        location: None,
    };

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[review_skill], None, None);

    assert!(prompt.contains("<available_skills>"), "missing skills XML");
    for fragment in [
        "<name>code-review</name>",
        "<description>Review code for bugs</description>",
        "SKILL.md</location>",
        "<instructions>",
        "<instruction>Always run cargo test before final response.</instruction>",
        "<tools>",
        "<name>lint</name>",
        "<kind>shell</kind>",
    ] {
        assert!(prompt.contains(fragment), "prompt is missing {fragment:?}");
    }
    // Instructions and tools are rendered inline, so no deferred-loading notice appears.
    assert!(!prompt.contains("loaded on demand"));
}
|
||
|
||
#[test]
fn prompt_skills_escape_reserved_xml_chars() {
    // Every user-controlled skill field below contains XML-reserved characters. In the
    // rendered `<available_skills>` block they must appear entity-escaped, so the markup
    // stays well-formed and skill text cannot inject or close tags.
    let ws = make_workspace();
    let skills = vec![crate::skills::Skill {
        name: "code<review>&".into(),
        description: "Review \"unsafe\" and 'risky' bits".into(),
        version: "1.0.0".into(),
        author: None,
        tags: vec![],
        tools: vec![crate::skills::SkillTool {
            name: "run\"linter\"".into(),
            description: "Run <lint> & report".into(),
            kind: "shell&exec".into(),
            command: "cargo clippy".into(),
            args: HashMap::new(),
        }],
        prompts: vec!["Use <tool_call> and & keep output \"safe\"".into()],
        location: None,
    }];

    let prompt = build_system_prompt(ws.path(), "model", &[], &skills, None, None);

    // Expected values use the XML predefined entities for < > & " '.
    assert!(prompt.contains("<name>code&lt;review&gt;&amp;</name>"));
    assert!(prompt.contains(
        "<description>Review &quot;unsafe&quot; and &apos;risky&apos; bits</description>"
    ));
    assert!(prompt.contains("<name>run&quot;linter&quot;</name>"));
    assert!(prompt.contains("<description>Run &lt;lint&gt; &amp; report</description>"));
    assert!(prompt.contains("<kind>shell&amp;exec</kind>"));
    assert!(prompt.contains(
        "<instruction>Use &lt;tool_call&gt; and &amp; keep output &quot;safe&quot;</instruction>"
    ));
}
|
||
|
||
#[test]
fn prompt_truncation() {
    let workspace = make_workspace();

    // AGENTS.md is 1000 characters longer than BOOTSTRAP_MAX_CHARS.
    let oversized = "x".repeat(BOOTSTRAP_MAX_CHARS + 1000);
    std::fs::write(workspace.path().join("AGENTS.md"), &oversized).unwrap();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    assert!(
        prompt.contains("truncated at"),
        "large files should be truncated"
    );
    assert!(
        !prompt.contains(oversized.as_str()),
        "full content should not appear"
    );
}
|
||
|
||
#[test]
fn prompt_empty_files_skipped() {
    let workspace = make_workspace();
    // Replace TOOLS.md with a zero-length file.
    std::fs::write(workspace.path().join("TOOLS.md"), "").unwrap();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    // An empty file contributes no section header at all.
    let tools_header_present = prompt.contains("### TOOLS.md");
    assert!(!tools_header_present, "empty files should be skipped");
}
|
||
|
||
#[test]
fn channel_log_truncation_is_utf8_safe_for_multibyte_text() {
    // Well over 80 characters, with a 4-byte emoji and a 2-byte `é` before the cut point,
    // so a byte-indexed slice at the wrong offset would land mid-character.
    let msg = "Hello from ZeroClaw 🌍. Current status is healthy, and café-style UTF-8 text stays safe in logs.";

    // Reproduces the production crash path where channel logs truncate at 80 chars.
    let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80));
    assert!(
        result.is_ok(),
        "truncate_with_ellipsis should never panic on UTF-8"
    );

    let truncated = result.unwrap();
    assert!(!truncated.is_empty());
    // Any `String`/`&str` is valid UTF-8 by construction, so a char-boundary check on its
    // own length proves nothing. Instead check that the text really was shortened and
    // that the multi-byte prefix survived intact.
    assert!(
        truncated.chars().count() < msg.chars().count(),
        "expected {truncated:?} to be shorter than the input"
    );
    assert!(truncated.starts_with("Hello from ZeroClaw 🌍"));
}
|
||
|
||
#[test]
fn prompt_contains_channel_capabilities() {
    let workspace = make_workspace();
    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    let expectations = [
        ("## Channel Capabilities", "missing Channel Capabilities section"),
        ("running as a messaging bot", "missing channel context"),
        (
            "NEVER repeat, describe, or echo credentials",
            "missing security instruction",
        ),
    ];
    for (needle, failure) in expectations {
        assert!(prompt.contains(needle), "{}", failure);
    }
}
|
||
|
||
#[test]
fn prompt_workspace_path() {
    let workspace = make_workspace();
    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    // The workspace root is shown as a backtick-quoted working directory.
    let expected_line = format!("Working directory: `{}`", workspace.path().display());
    assert!(prompt.contains(&expected_line));
}
|
||
|
||
#[test]
fn conversation_memory_key_uses_message_id() {
    let slack_msg = traits::ChannelMessage {
        id: "msg_abc123".to_string(),
        sender: "U123".to_string(),
        reply_target: "C456".to_string(),
        content: "hello".to_string(),
        channel: "slack".to_string(),
        timestamp: 1,
        thread_ts: None,
    };

    // Channel, sender and message id are joined with underscores.
    let key = conversation_memory_key(&slack_msg);
    assert_eq!(key, "slack_U123_msg_abc123");
}
|
||
|
||
#[test]
fn conversation_memory_key_is_unique_per_message() {
    // Same channel and sender; only the id, content and timestamp differ.
    let slack_msg = |id: &str, content: &str, timestamp| traits::ChannelMessage {
        id: id.to_string(),
        sender: "U123".to_string(),
        reply_target: "C456".to_string(),
        content: content.to_string(),
        channel: "slack".to_string(),
        timestamp,
        thread_ts: None,
    };
    let first = slack_msg("msg_1", "first", 1);
    let second = slack_msg("msg_2", "second", 2);

    assert_ne!(
        conversation_memory_key(&first),
        conversation_memory_key(&second)
    );
}
|
||
|
||
#[tokio::test]
|
||
async fn autosave_keys_preserve_multiple_conversation_facts() {
|
||
let tmp = TempDir::new().unwrap();
|
||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||
|
||
let msg1 = traits::ChannelMessage {
|
||
id: "msg_1".into(),
|
||
sender: "U123".into(),
|
||
reply_target: "C456".into(),
|
||
content: "I'm Paul".into(),
|
||
channel: "slack".into(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
};
|
||
let msg2 = traits::ChannelMessage {
|
||
id: "msg_2".into(),
|
||
sender: "U123".into(),
|
||
reply_target: "C456".into(),
|
||
content: "I'm 45".into(),
|
||
channel: "slack".into(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
};
|
||
|
||
mem.store(
|
||
&conversation_memory_key(&msg1),
|
||
&msg1.content,
|
||
MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
mem.store(
|
||
&conversation_memory_key(&msg2),
|
||
&msg2.content,
|
||
MemoryCategory::Conversation,
|
||
None,
|
||
)
|
||
.await
|
||
.unwrap();
|
||
|
||
assert_eq!(mem.count().await.unwrap(), 2);
|
||
|
||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn build_memory_context_includes_recalled_entries() {
|
||
let tmp = TempDir::new().unwrap();
|
||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||
.await
|
||
.unwrap();
|
||
|
||
let context = build_memory_context(&mem, "age", 0.0).await;
|
||
assert!(context.contains("[Memory context]"));
|
||
assert!(context.contains("Age is 45"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_restores_per_sender_history_on_follow_ups() {
|
||
let channel_impl = Arc::new(RecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "msg-a".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
process_channel_message(
|
||
runtime_ctx,
|
||
traits::ChannelMessage {
|
||
id: "msg-b".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-1".to_string(),
|
||
content: "follow up".to_string(),
|
||
channel: "test-channel".to_string(),
|
||
timestamp: 2,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 2);
|
||
assert_eq!(calls[0].len(), 2);
|
||
assert_eq!(calls[0][0].0, "system");
|
||
assert_eq!(calls[0][1].0, "user");
|
||
assert_eq!(calls[1].len(), 4);
|
||
assert_eq!(calls[1][0].0, "system");
|
||
assert_eq!(calls[1][1].0, "user");
|
||
assert_eq!(calls[1][2].0, "assistant");
|
||
assert_eq!(calls[1][3].0, "user");
|
||
assert!(calls[1][1].1.contains("hello"));
|
||
assert!(calls[1][2].1.contains("response-1"));
|
||
assert!(calls[1][3].1.contains("follow up"));
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
|
||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||
|
||
let mut channels_by_name = HashMap::new();
|
||
channels_by_name.insert(channel.name().to_string(), channel);
|
||
|
||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||
let mut histories = HashMap::new();
|
||
histories.insert(
|
||
"telegram_alice".to_string(),
|
||
vec![
|
||
ChatMessage::assistant("stale assistant"),
|
||
ChatMessage::user("earlier user question"),
|
||
ChatMessage::assistant("earlier assistant reply"),
|
||
],
|
||
);
|
||
|
||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||
channels_by_name: Arc::new(channels_by_name),
|
||
provider: provider_impl.clone(),
|
||
default_provider: Arc::new("test-provider".to_string()),
|
||
memory: Arc::new(NoopMemory),
|
||
tools_registry: Arc::new(vec![]),
|
||
observer: Arc::new(NoopObserver),
|
||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||
model: Arc::new("test-model".to_string()),
|
||
temperature: 0.0,
|
||
auto_save_memory: false,
|
||
max_tool_iterations: 5,
|
||
min_relevance_score: 0.0,
|
||
conversation_histories: Arc::new(Mutex::new(histories)),
|
||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||
api_key: None,
|
||
api_url: None,
|
||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||
interrupt_on_new_message: false,
|
||
multimodal: crate::config::MultimodalConfig::default(),
|
||
});
|
||
|
||
process_channel_message(
|
||
runtime_ctx.clone(),
|
||
traits::ChannelMessage {
|
||
id: "tg-msg-1".to_string(),
|
||
sender: "alice".to_string(),
|
||
reply_target: "chat-telegram".to_string(),
|
||
content: "hello".to_string(),
|
||
channel: "telegram".to_string(),
|
||
timestamp: 1,
|
||
thread_ts: None,
|
||
},
|
||
CancellationToken::new(),
|
||
)
|
||
.await;
|
||
|
||
let calls = provider_impl
|
||
.calls
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner());
|
||
assert_eq!(calls.len(), 1);
|
||
assert_eq!(calls[0].len(), 4);
|
||
|
||
let roles = calls[0]
|
||
.iter()
|
||
.map(|(role, _)| role.as_str())
|
||
.collect::<Vec<_>>();
|
||
assert_eq!(roles, vec!["system", "user", "assistant", "user"]);
|
||
assert!(
|
||
calls[0][0]
|
||
.1
|
||
.contains("When responding on Telegram, include media markers"),
|
||
"telegram delivery instruction should live in the system prompt"
|
||
);
|
||
assert!(!calls[0].iter().skip(1).any(|(role, _)| role == "system"));
|
||
}
|
||
|
||
// ── AIEOS Identity Tests (Issue #168) ─────────────────────────
|
||
|
||
/// Writes an AIEOS identity JSON file into a fresh workspace and points the
/// config at it by a workspace-relative path.
///
/// The built prompt must render the Identity, Personality and Communication
/// Style sections from that file. It must not include the OpenClaw bootstrap
/// headers (`### SOUL.md`, `### IDENTITY.md`) or a `[File not found` marker.
#[test]
fn aieos_identity_from_file() {
    use crate::config::IdentityConfig;

    let workspace = tempfile::TempDir::new().unwrap();

    let aieos_json = r#"{
        "identity": {
            "names": {"first": "Nova", "nickname": "Nov"},
            "bio": "A helpful AI assistant.",
            "origin": "Silicon Valley"
        },
        "psychology": {
            "mbti": "INTJ",
            "moral_compass": ["Be helpful", "Do no harm"]
        },
        "linguistics": {
            "style": "concise",
            "formality": "casual"
        }
    }"#;
    std::fs::write(workspace.path().join("aieos_identity.json"), aieos_json).unwrap();

    // The config names only the file; the prompt builder is given the
    // workspace directory the file was written into.
    let config = IdentityConfig {
        format: "aieos".into(),
        aieos_path: Some("aieos_identity.json".into()),
        aieos_inline: None,
    };

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], Some(&config), None);

    let expected = [
        "## Identity",
        "**Name:** Nova",
        "**Nickname:** Nov",
        "**Bio:** A helpful AI assistant.",
        "**Origin:** Silicon Valley",
        "## Personality",
        "**MBTI:** INTJ",
        "**Moral Compass:**",
        "- Be helpful",
        "## Communication Style",
        "**Style:** concise",
        "**Formality Level:** casual",
    ];
    for fragment in expected {
        assert!(prompt.contains(fragment), "missing {:?}", fragment);
    }

    let forbidden = ["### SOUL.md", "### IDENTITY.md", "[File not found"];
    for fragment in forbidden {
        assert!(!prompt.contains(fragment), "unexpected {:?}", fragment);
    }
}
|
||
|
||
/// An AIEOS identity supplied inline in the config (with no `aieos_path`)
/// is rendered into the prompt's Identity section.
#[test]
fn aieos_identity_from_inline() {
    use crate::config::IdentityConfig;

    let inline_identity = r#"{"identity":{"names":{"first":"Claw"}}}"#;
    let config = IdentityConfig {
        format: "aieos".into(),
        aieos_path: None,
        aieos_inline: Some(inline_identity.into()),
    };
    let workspace = std::env::temp_dir();

    let prompt = build_system_prompt(workspace.as_path(), "model", &[], &[], Some(&config), None);

    assert!(prompt.contains("**Name:** Claw"));
    assert!(prompt.contains("## Identity"));
}
|
||
|
||
/// If the configured AIEOS file cannot be loaded (here it does not exist),
/// the prompt falls back to the OpenClaw bootstrap files.
#[test]
fn aieos_fallback_to_openclaw_on_parse_error() {
    use crate::config::IdentityConfig;

    let workspace = make_workspace();
    let missing_file = IdentityConfig {
        format: "aieos".into(),
        aieos_path: Some("nonexistent.json".into()),
        aieos_inline: None,
    };

    let prompt = build_system_prompt(
        workspace.path(),
        "model",
        &[],
        &[],
        Some(&missing_file),
        None,
    );

    // The load error is logged to stderr with the filename rather than
    // embedded in the prompt; here we only check the fallback content.
    assert!(prompt.contains("### SOUL.md"));
}
|
||
|
||
/// Selecting the "aieos" format without a path or inline JSON is treated as
/// unconfigured: the prompt uses the OpenClaw bootstrap files.
#[test]
fn aieos_empty_uses_openclaw() {
    use crate::config::IdentityConfig;

    let workspace = make_workspace();
    let unconfigured = IdentityConfig {
        format: "aieos".into(),
        aieos_path: None,
        aieos_inline: None,
    };

    let prompt = build_system_prompt(
        workspace.path(),
        "model",
        &[],
        &[],
        Some(&unconfigured),
        None,
    );

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment));
    }
}
|
||
|
||
/// With the "openclaw" format, a configured `aieos_path` is ignored: the
/// prompt is built from the bootstrap files and has no AIEOS Identity
/// section.
#[test]
fn openclaw_format_uses_bootstrap_files() {
    use crate::config::IdentityConfig;

    let workspace = make_workspace();
    let openclaw = IdentityConfig {
        format: "openclaw".into(),
        aieos_path: Some("identity.json".into()),
        aieos_inline: None,
    };

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], Some(&openclaw), None);

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment));
    }
    assert!(!prompt.contains("## Identity"));
}
|
||
|
||
/// With no identity config at all, the prompt uses the OpenClaw bootstrap
/// files.
#[test]
fn none_identity_config_uses_openclaw() {
    let workspace = make_workspace();

    let prompt = build_system_prompt(workspace.path(), "model", &[], &[], None, None);

    for fragment in ["### SOUL.md", "Be helpful"] {
        assert!(prompt.contains(fragment));
    }
}
|
||
|
||
/// A probe that completed and reported `true` is classified as healthy.
#[test]
fn classify_health_ok_true() {
    assert_eq!(
        classify_health_result(&Ok(true)),
        ChannelHealthState::Healthy
    );
}
|
||
|
||
/// A probe that completed but reported `false` is classified as unhealthy.
#[test]
fn classify_health_ok_false() {
    assert_eq!(
        classify_health_result(&Ok(false)),
        ChannelHealthState::Unhealthy
    );
}
|
||
|
||
#[tokio::test]
|
||
async fn classify_health_timeout() {
|
||
let result = tokio::time::timeout(Duration::from_millis(1), async {
|
||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||
true
|
||
})
|
||
.await;
|
||
let state = classify_health_result(&result);
|
||
assert_eq!(state, ChannelHealthState::Timeout);
|
||
}
|
||
|
||
/// Test double for a channel whose listener always errors.
struct AlwaysFailChannel {
    /// Name returned from `Channel::name`.
    name: &'static str,
    /// Counter shared with the test; bumped once per `listen` invocation.
    calls: Arc<AtomicUsize>,
}
|
||
|
||
#[async_trait::async_trait]
|
||
impl Channel for AlwaysFailChannel {
|
||
fn name(&self) -> &str {
|
||
self.name
|
||
}
|
||
|
||
async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
|
||
Ok(())
|
||
}
|
||
|
||
async fn listen(
|
||
&self,
|
||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||
) -> anyhow::Result<()> {
|
||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||
anyhow::bail!("listen boom")
|
||
}
|
||
}
|
||
|
||
#[tokio::test]
|
||
async fn supervised_listener_marks_error_and_restarts_on_failures() {
|
||
let calls = Arc::new(AtomicUsize::new(0));
|
||
let channel: Arc<dyn Channel> = Arc::new(AlwaysFailChannel {
|
||
name: "test-supervised-fail",
|
||
calls: Arc::clone(&calls),
|
||
});
|
||
|
||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
|
||
let handle = spawn_supervised_listener(channel, tx, 1, 1);
|
||
|
||
tokio::time::sleep(Duration::from_millis(80)).await;
|
||
drop(rx);
|
||
handle.abort();
|
||
let _ = handle.await;
|
||
|
||
let snapshot = crate::health::snapshot_json();
|
||
let component = &snapshot["components"]["channel:test-supervised-fail"];
|
||
assert_eq!(component["status"], "error");
|
||
assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
|
||
assert!(component["last_error"]
|
||
.as_str()
|
||
.unwrap_or("")
|
||
.contains("listen boom"));
|
||
assert!(calls.load(Ordering::SeqCst) >= 1);
|
||
}
|
||
}
|