fix(channels): process inbound messages concurrently (#267)

Fixes #235
This commit is contained in:
Chummy 2026-02-16 14:58:01 +08:00 committed by GitHub
parent 3bdabdc7ec
commit c481f5298a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -26,6 +26,7 @@ use crate::memory::{self, Memory};
use crate::providers::{self, Provider};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;
use std::time::{Duration, Instant};
@ -36,6 +37,20 @@ const BOOTSTRAP_MAX_CHARS: usize = 20_000;
/// Default initial backoff, in seconds.
/// NOTE(review): presumably the first delay before a failed channel listener is restarted;
/// the code that reads it is not visible here, so confirm.
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
/// Default maximum backoff, in seconds. `start_channels` passes it to `.max(...)` when
/// deriving the effective backoff ceiling.
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
/// Seconds `process_channel_message` waits for the provider before giving up and telling
/// the sender that the request timed out.
const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 90;
/// Multiplier applied to the channel count by `compute_max_in_flight_messages`.
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
/// Lower bound of the value returned by `compute_max_in_flight_messages`.
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
/// Upper bound of the value returned by `compute_max_in_flight_messages`.
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
/// Everything `process_channel_message` needs to handle one inbound message.
///
/// `start_channels` builds a single instance, wraps it in an `Arc`, and the dispatch
/// loop hands a clone of that `Arc` to every spawned worker.
#[derive(Clone)]
struct ChannelRuntimeContext {
/// Channels keyed by `Channel::name()`; the inbound message's `channel` field is used
/// as the lookup key to find where the reply goes.
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
/// Provider whose `chat_with_system` produces the reply.
provider: Arc<dyn Provider>,
/// Memory backend used to build recall context and, when enabled, to store the message.
memory: Arc<dyn Memory>,
/// System prompt passed to the provider on every call.
system_prompt: Arc<String>,
/// Model identifier passed to the provider on every call.
model: Arc<String>,
/// Temperature passed to the provider on every call.
temperature: f64,
/// When `true`, each inbound message is stored in `memory` under the
/// `Conversation` category.
auto_save_memory: bool,
}
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
@ -97,6 +112,151 @@ fn spawn_supervised_listener(
})
}
/// Returns how many inbound messages may be processed at the same time.
///
/// The limit is `channel_count * CHANNEL_PARALLELISM_PER_CHANNEL` (saturating on
/// overflow), kept within
/// `CHANNEL_MIN_IN_FLIGHT_MESSAGES..=CHANNEL_MAX_IN_FLIGHT_MESSAGES`.
fn compute_max_in_flight_messages(channel_count: usize) -> usize {
    let scaled_by_channels = channel_count.saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL);
    scaled_by_channels.clamp(
        CHANNEL_MIN_IN_FLIGHT_MESSAGES,
        CHANNEL_MAX_IN_FLIGHT_MESSAGES,
    )
}
fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
if let Err(error) = result {
tracing::error!("Channel message worker crashed: {error}");
}
}
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
println!(
" 💬 [{}] from {}: {}",
msg.channel,
msg.sender,
truncate_with_ellipsis(&msg.content, 80)
);
let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await;
if ctx.auto_save_memory {
let autosave_key = conversation_memory_key(&msg);
let _ = ctx
.memory
.store(
&autosave_key,
&msg.content,
crate::memory::MemoryCategory::Conversation,
)
.await;
}
let enriched_message = if memory_context.is_empty() {
msg.content.clone()
} else {
format!("{memory_context}{}", msg.content)
};
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.start_typing(&msg.sender).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
}
}
println!(" ⏳ Processing message...");
let started_at = Instant::now();
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
ctx.provider.chat_with_system(
Some(ctx.system_prompt.as_str()),
&enriched_message,
ctx.model.as_str(),
ctx.temperature,
),
)
.await;
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.stop_typing(&msg.sender).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
}
}
match llm_result {
Ok(Ok(response)) => {
println!(
" 🤖 Reply ({}ms): {}",
started_at.elapsed().as_millis(),
truncate_with_ellipsis(&response, 80)
);
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.send(&response, &msg.sender).await {
eprintln!(" ❌ Failed to reply on {}: {e}", channel.name());
}
}
}
Ok(Err(e)) => {
eprintln!(
" ❌ LLM error after {}ms: {e}",
started_at.elapsed().as_millis()
);
if let Some(channel) = target_channel.as_ref() {
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
}
}
Err(_) => {
let timeout_msg = format!(
"LLM response timed out after {}s",
CHANNEL_MESSAGE_TIMEOUT_SECS
);
eprintln!(
" ❌ {} (elapsed: {}ms)",
timeout_msg,
started_at.elapsed().as_millis()
);
if let Some(channel) = target_channel.as_ref() {
let _ = channel
.send(
"⚠️ Request timed out while waiting for the model. Please try again.",
&msg.sender,
)
.await;
}
}
}
}
async fn run_message_dispatch_loop(
mut rx: tokio::sync::mpsc::Receiver<traits::ChannelMessage>,
ctx: Arc<ChannelRuntimeContext>,
max_in_flight_messages: usize,
) {
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
let mut workers = tokio::task::JoinSet::new();
while let Some(msg) = rx.recv().await {
let permit = match Arc::clone(&semaphore).acquire_owned().await {
Ok(permit) => permit,
Err(_) => break,
};
let worker_ctx = Arc::clone(&ctx);
workers.spawn(async move {
let _permit = permit;
process_channel_message(worker_ctx, msg).await;
});
while let Some(result) = workers.try_join_next() {
log_worker_join_result(result);
}
}
while let Some(result) = workers.join_next().await {
log_worker_join_result(result);
}
}
/// Load OpenClaw format bootstrap files into the prompt.
fn load_openclaw_bootstrap_files(prompt: &mut String, workspace_dir: &std::path::Path) {
prompt
@ -680,7 +840,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
// Single message bus — all channels send messages here
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
// Spawn a listener for each channel
let mut handles = Vec::new();
@ -694,104 +854,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
}
drop(tx); // Drop our copy so rx closes when all channels stop
// Process incoming messages — call the LLM and reply
while let Some(msg) = rx.recv().await {
println!(
" 💬 [{}] from {}: {}",
msg.channel,
msg.sender,
truncate_with_ellipsis(&msg.content, 80)
let channels_by_name = Arc::new(
channels
.iter()
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
.collect::<HashMap<_, _>>(),
);
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
let memory_context = build_memory_context(mem.as_ref(), &msg.content).await;
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
// Auto-save to memory
if config.memory.auto_save {
let autosave_key = conversation_memory_key(&msg);
let _ = mem
.store(
&autosave_key,
&msg.content,
crate::memory::MemoryCategory::Conversation,
)
.await;
}
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name,
provider: Arc::clone(&provider),
memory: Arc::clone(&mem),
system_prompt: Arc::new(system_prompt),
model: Arc::new(model.clone()),
temperature,
auto_save_memory: config.memory.auto_save,
});
let enriched_message = if memory_context.is_empty() {
msg.content.clone()
} else {
format!("{memory_context}{}", msg.content)
};
let target_channel = channels.iter().find(|ch| ch.name() == msg.channel);
// Show typing indicator while processing
if let Some(ch) = target_channel {
if let Err(e) = ch.start_typing(&msg.sender).await {
tracing::debug!("Failed to start typing on {}: {e}", ch.name());
}
}
// Call the LLM with system prompt (identity + soul + tools)
println!(" ⏳ Processing message...");
let started_at = Instant::now();
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
provider.chat_with_system(Some(&system_prompt), &enriched_message, &model, temperature),
)
.await;
// Stop typing before sending the response
if let Some(ch) = target_channel {
if let Err(e) = ch.stop_typing(&msg.sender).await {
tracing::debug!("Failed to stop typing on {}: {e}", ch.name());
}
}
match llm_result {
Ok(Ok(response)) => {
println!(
" 🤖 Reply ({}ms): {}",
started_at.elapsed().as_millis(),
truncate_with_ellipsis(&response, 80)
);
if let Some(ch) = target_channel {
if let Err(e) = ch.send(&response, &msg.sender).await {
eprintln!(" ❌ Failed to reply on {}: {e}", ch.name());
}
}
}
Ok(Err(e)) => {
eprintln!(
" ❌ LLM error after {}ms: {e}",
started_at.elapsed().as_millis()
);
if let Some(ch) = target_channel {
let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
}
}
Err(_) => {
let timeout_msg = format!(
"LLM response timed out after {}s",
CHANNEL_MESSAGE_TIMEOUT_SECS
);
eprintln!(
" ❌ {} (elapsed: {}ms)",
timeout_msg,
started_at.elapsed().as_millis()
);
if let Some(ch) = target_channel {
let _ = ch
.send(
"⚠️ Request timed out while waiting for the model. Please try again.",
&msg.sender,
)
.await;
}
}
}
}
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
// Wait for all channel tasks
for h in handles {
@ -805,6 +888,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
mod tests {
use super::*;
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::providers::Provider;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tempfile::TempDir;
@ -830,6 +915,155 @@ mod tests {
tmp
}
/// Test channel that records every outgoing message instead of delivering it.
#[derive(Default)]
struct RecordingChannel {
    /// One `"{recipient}:{message}"` entry per `send` call, in call order.
    sent_messages: tokio::sync::Mutex<Vec<String>>,
}

#[async_trait::async_trait]
impl Channel for RecordingChannel {
    /// Fixed name that test messages use in their `channel` field.
    fn name(&self) -> &str {
        "test-channel"
    }

    /// Appends the message to `sent_messages`; never fails.
    async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
        let entry = format!("{recipient}:{message}");
        let mut recorded = self.sent_messages.lock().await;
        recorded.push(entry);
        Ok(())
    }

    /// Produces no inbound messages; returns immediately.
    async fn listen(
        &self,
        _tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
    ) -> anyhow::Result<()> {
        Ok(())
    }
}
/// Test provider that answers after a fixed delay.
struct SlowProvider {
    /// How long each `chat_with_system` call sleeps before replying.
    delay: Duration,
}

#[async_trait::async_trait]
impl Provider for SlowProvider {
    /// Sleeps for `delay`, then echoes `message` back prefixed with `"echo: "`.
    /// The system prompt, model and temperature are ignored.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        tokio::time::sleep(self.delay).await;
        let reply = format!("echo: {message}");
        Ok(reply)
    }
}
/// Test memory backend that stores nothing and always reports itself empty.
struct NoopMemory;

#[async_trait::async_trait]
impl Memory for NoopMemory {
    fn name(&self) -> &str {
        "noop"
    }

    /// Accepts the entry and discards it.
    async fn store(
        &self,
        _key: &str,
        _content: &str,
        _category: crate::memory::MemoryCategory,
    ) -> anyhow::Result<()> {
        Ok(())
    }

    /// Never finds anything.
    async fn recall(
        &self,
        _query: &str,
        _limit: usize,
    ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
        Ok(vec![])
    }

    /// No key is ever present.
    async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
        Ok(None)
    }

    /// Always an empty listing, whatever the category filter.
    async fn list(
        &self,
        _category: Option<&crate::memory::MemoryCategory>,
    ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
        Ok(vec![])
    }

    /// Nothing to remove, so reports `false`.
    async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
        Ok(false)
    }

    async fn count(&self) -> anyhow::Result<usize> {
        Ok(0)
    }

    async fn health_check(&self) -> bool {
        true
    }
}
/// Two messages, each answered by a provider that takes 250ms, must both be
/// replied to in less time than handling them one after the other would take.
#[tokio::test]
async fn message_dispatch_processes_messages_in_parallel() {
    let channel_impl = Arc::new(RecordingChannel::default());
    let channel: Arc<dyn Channel> = channel_impl.clone();
    let channels_by_name = HashMap::from([(channel.name().to_string(), channel)]);

    let provider = SlowProvider {
        delay: Duration::from_millis(250),
    };
    let runtime_ctx = Arc::new(ChannelRuntimeContext {
        channels_by_name: Arc::new(channels_by_name),
        provider: Arc::new(provider),
        memory: Arc::new(NoopMemory),
        system_prompt: Arc::new("test-system-prompt".to_string()),
        model: Arc::new("test-model".to_string()),
        temperature: 0.0,
        auto_save_memory: false,
    });

    // Queue both messages up front, then close the sender so the loop can finish.
    let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
    let inbound = [("1", "alice", "hello", 1), ("2", "bob", "world", 2)];
    for (id, sender, content, timestamp) in inbound {
        tx.send(traits::ChannelMessage {
            id: id.to_string(),
            sender: sender.to_string(),
            content: content.to_string(),
            channel: "test-channel".to_string(),
            timestamp,
        })
        .await
        .unwrap();
    }
    drop(tx);

    let started = Instant::now();
    run_message_dispatch_loop(rx, runtime_ctx, 2).await;
    let elapsed = started.elapsed();

    // Sequential handling would need at least 500ms.
    assert!(
        elapsed < Duration::from_millis(430),
        "expected parallel dispatch (<430ms), got {:?}",
        elapsed
    );

    let sent_messages = channel_impl.sent_messages.lock().await;
    assert_eq!(sent_messages.len(), 2);
}
#[test]
fn prompt_contains_all_sections() {
let ws = make_workspace();