From c481f5298a0e3b032be7826cb07afb9631ecfa69 Mon Sep 17 00:00:00 2001
From: Chummy
Date: Mon, 16 Feb 2026 14:58:01 +0800
Subject: [PATCH] fix(channels): process inbound messages concurrently (#267)

The message loop in start_channels awaited each LLM call inline, so a
slow reply on one channel stalled every other channel. Dispatch each
inbound message on its own task instead, bounded by a semaphore sized
at CHANNEL_PARALLELISM_PER_CHANNEL permits per configured channel and
clamped to the 8..=64 range.

Fixes #235
---
 src/channels/mod.rs | 426 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 330 insertions(+), 96 deletions(-)

diff --git a/src/channels/mod.rs b/src/channels/mod.rs
index 92b5526..a828f53 100644
--- a/src/channels/mod.rs
+++ b/src/channels/mod.rs
@@ -26,6 +26,7 @@ use crate::memory::{self, Memory};
 use crate::providers::{self, Provider};
 use crate::util::truncate_with_ellipsis;
 use anyhow::Result;
+use std::collections::HashMap;
 use std::fmt::Write;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
@@ -36,6 +37,20 @@ const BOOTSTRAP_MAX_CHARS: usize = 20_000;
 const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
 const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
 const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 90;
+const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
+const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
+const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
+
+#[derive(Clone)]
+struct ChannelRuntimeContext {
+    channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
+    provider: Arc<dyn Provider>,
+    memory: Arc<dyn Memory>,
+    system_prompt: Arc<String>,
+    model: Arc<String>,
+    temperature: f64,
+    auto_save_memory: bool,
+}
 
 fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
     format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
 }
@@ -97,6 +112,151 @@ fn spawn_supervised_listener(
     })
 }
 
+fn compute_max_in_flight_messages(channel_count: usize) -> usize {
+    channel_count
+        .saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
+        .clamp(
+            CHANNEL_MIN_IN_FLIGHT_MESSAGES,
+            CHANNEL_MAX_IN_FLIGHT_MESSAGES,
+        )
+}
+
+fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
+    if let Err(error) = result {
+        tracing::error!("Channel message worker crashed: {error}");
+    }
+}
+
+async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
+    println!(
+        " 💬 [{}] from {}: {}",
+        msg.channel,
+        msg.sender,
+        truncate_with_ellipsis(&msg.content, 80)
+    );
+
+    let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await;
+
+    if ctx.auto_save_memory {
+        let autosave_key = conversation_memory_key(&msg);
+        let _ = ctx
+            .memory
+            .store(
+                &autosave_key,
+                &msg.content,
+                crate::memory::MemoryCategory::Conversation,
+            )
+            .await;
+    }
+
+    let enriched_message = if memory_context.is_empty() {
+        msg.content.clone()
+    } else {
+        format!("{memory_context}{}", msg.content)
+    };
+
+    let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
+
+    if let Some(channel) = target_channel.as_ref() {
+        if let Err(e) = channel.start_typing(&msg.sender).await {
+            tracing::debug!("Failed to start typing on {}: {e}", channel.name());
+        }
+    }
+
+    println!(" ⏳ Processing message...");
+    let started_at = Instant::now();
+
+    let llm_result = tokio::time::timeout(
+        Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
+        ctx.provider.chat_with_system(
+            Some(ctx.system_prompt.as_str()),
+            &enriched_message,
+            ctx.model.as_str(),
+            ctx.temperature,
+        ),
+    )
+    .await;
+
+    if let Some(channel) = target_channel.as_ref() {
+        if let Err(e) = channel.stop_typing(&msg.sender).await {
+            tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
+        }
+    }
+
+    match llm_result {
+        Ok(Ok(response)) => {
+            println!(
+                " 🤖 Reply ({}ms): {}",
+                started_at.elapsed().as_millis(),
+                truncate_with_ellipsis(&response, 80)
+            );
+            if let Some(channel) = target_channel.as_ref() {
+                if let Err(e) = channel.send(&response, &msg.sender).await {
+                    eprintln!(" ❌ Failed to reply on {}: {e}", channel.name());
+                }
+            }
+        }
+        Ok(Err(e)) => {
+            eprintln!(
+                " ❌ LLM error after {}ms: {e}",
+                started_at.elapsed().as_millis()
+            );
+            if let Some(channel) = target_channel.as_ref() {
+                let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
+            }
+        }
+        Err(_) => {
+            let timeout_msg = format!(
+                "LLM response timed out after {}s",
+                CHANNEL_MESSAGE_TIMEOUT_SECS
+            );
+            eprintln!(
+                " ❌ {} (elapsed: {}ms)",
+                timeout_msg,
+                started_at.elapsed().as_millis()
+            );
+            if let Some(channel) = target_channel.as_ref() {
+                let _ = channel
+                    .send(
+                        "⚠️ Request timed out while waiting for the model. Please try again.",
+                        &msg.sender,
+                    )
+                    .await;
+            }
+        }
+    }
+}
+
+async fn run_message_dispatch_loop(
+    mut rx: tokio::sync::mpsc::Receiver<traits::ChannelMessage>,
+    ctx: Arc<ChannelRuntimeContext>,
+    max_in_flight_messages: usize,
+) {
+    let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
+    let mut workers = tokio::task::JoinSet::new();
+
+    while let Some(msg) = rx.recv().await {
+        let permit = match Arc::clone(&semaphore).acquire_owned().await {
+            Ok(permit) => permit,
+            Err(_) => break,
+        };
+
+        let worker_ctx = Arc::clone(&ctx);
+        workers.spawn(async move {
+            let _permit = permit;
+            process_channel_message(worker_ctx, msg).await;
+        });
+
+        while let Some(result) = workers.try_join_next() {
+            log_worker_join_result(result);
+        }
+    }
+
+    while let Some(result) = workers.join_next().await {
+        log_worker_join_result(result);
+    }
+}
+
 /// Load OpenClaw format bootstrap files into the prompt.
 fn load_openclaw_bootstrap_files(prompt: &mut String, workspace_dir: &std::path::Path) {
     prompt
@@ -680,7 +840,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
         .max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
 
     // Single message bus — all channels send messages here
-    let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
+    let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
 
     // Spawn a listener for each channel
     let mut handles = Vec::new();
@@ -694,104 +854,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
     }
     drop(tx); // Drop our copy so rx closes when all channels stop
 
-    // Process incoming messages — call the LLM and reply
-    while let Some(msg) = rx.recv().await {
-        println!(
-            " 💬 [{}] from {}: {}",
-            msg.channel,
-            msg.sender,
-            truncate_with_ellipsis(&msg.content, 80)
-        );
+    let channels_by_name = Arc::new(
+        channels
+            .iter()
+            .map(|ch| (ch.name().to_string(), Arc::clone(ch)))
+            .collect::<HashMap<_, _>>(),
+    );
+    let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
 
-        let memory_context = build_memory_context(mem.as_ref(), &msg.content).await;
+    println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
 
-        // Auto-save to memory
-        if config.memory.auto_save {
-            let autosave_key = conversation_memory_key(&msg);
-            let _ = mem
-                .store(
-                    &autosave_key,
-                    &msg.content,
-                    crate::memory::MemoryCategory::Conversation,
-                )
-                .await;
-        }
+    let runtime_ctx = Arc::new(ChannelRuntimeContext {
+        channels_by_name,
+        provider: Arc::clone(&provider),
+        memory: Arc::clone(&mem),
+        system_prompt: Arc::new(system_prompt),
+        model: Arc::new(model.clone()),
+        temperature,
+        auto_save_memory: config.memory.auto_save,
+    });
 
-        let enriched_message = if memory_context.is_empty() {
-            msg.content.clone()
-        } else {
-            format!("{memory_context}{}", msg.content)
-        };
-
-        let target_channel = channels.iter().find(|ch| ch.name() == msg.channel);
-
-        // Show typing indicator while processing
-        if let Some(ch) = target_channel {
-            if let Err(e) = ch.start_typing(&msg.sender).await {
-                tracing::debug!("Failed to start typing on {}: {e}", ch.name());
-            }
-        }
-
-        // Call the LLM with system prompt (identity + soul + tools)
-        println!(" ⏳ Processing message...");
-        let started_at = Instant::now();
-
-        let llm_result = tokio::time::timeout(
-            Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
-            provider.chat_with_system(Some(&system_prompt), &enriched_message, &model, temperature),
-        )
-        .await;
-
-        // Stop typing before sending the response
-        if let Some(ch) = target_channel {
-            if let Err(e) = ch.stop_typing(&msg.sender).await {
-                tracing::debug!("Failed to stop typing on {}: {e}", ch.name());
-            }
-        }
-
-        match llm_result {
-            Ok(Ok(response)) => {
-                println!(
-                    " 🤖 Reply ({}ms): {}",
-                    started_at.elapsed().as_millis(),
-                    truncate_with_ellipsis(&response, 80)
-                );
-                if let Some(ch) = target_channel {
-                    if let Err(e) = ch.send(&response, &msg.sender).await {
-                        eprintln!(" ❌ Failed to reply on {}: {e}", ch.name());
-                    }
-                }
-            }
-            Ok(Err(e)) => {
-                eprintln!(
-                    " ❌ LLM error after {}ms: {e}",
-                    started_at.elapsed().as_millis()
-                );
-                if let Some(ch) = target_channel {
-                    let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
-                }
-            }
-            Err(_) => {
-                let timeout_msg = format!(
-                    "LLM response timed out after {}s",
-                    CHANNEL_MESSAGE_TIMEOUT_SECS
-                );
-                eprintln!(
-                    " ❌ {} (elapsed: {}ms)",
-                    timeout_msg,
-                    started_at.elapsed().as_millis()
-                );
-                if let Some(ch) = target_channel {
-                    let _ = ch
-                        .send(
-                            "⚠️ Request timed out while waiting for the model. Please try again.",
-                            &msg.sender,
-                        )
-                        .await;
-                }
-            }
-        }
-    }
+    run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
 
     // Wait for all channel tasks
     for h in handles {
@@ -805,6 +888,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
 mod tests {
     use super::*;
     use crate::memory::{Memory, MemoryCategory, SqliteMemory};
+    use crate::providers::Provider;
+    use std::collections::HashMap;
     use std::sync::atomic::{AtomicUsize, Ordering};
     use std::sync::Arc;
     use tempfile::TempDir;
@@ -830,6 +915,155 @@ mod tests {
         tmp
     }
 
+    #[derive(Default)]
+    struct RecordingChannel {
+        sent_messages: tokio::sync::Mutex<Vec<String>>,
+    }
+
+    #[async_trait::async_trait]
+    impl Channel for RecordingChannel {
+        fn name(&self) -> &str {
+            "test-channel"
+        }
+
+        async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
+            self.sent_messages
+                .lock()
+                .await
+                .push(format!("{recipient}:{message}"));
+            Ok(())
+        }
+
+        async fn listen(
+            &self,
+            _tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
+        ) -> anyhow::Result<()> {
+            Ok(())
+        }
+    }
+
+    struct SlowProvider {
+        delay: Duration,
+    }
+
+    #[async_trait::async_trait]
+    impl Provider for SlowProvider {
+        async fn chat_with_system(
+            &self,
+            _system_prompt: Option<&str>,
+            message: &str,
+            _model: &str,
+            _temperature: f64,
+        ) -> anyhow::Result<String> {
+            tokio::time::sleep(self.delay).await;
+            Ok(format!("echo: {message}"))
+        }
+    }
+
+    struct NoopMemory;
+
+    #[async_trait::async_trait]
+    impl Memory for NoopMemory {
+        fn name(&self) -> &str {
+            "noop"
+        }
+
+        async fn store(
+            &self,
+            _key: &str,
+            _content: &str,
+            _category: crate::memory::MemoryCategory,
+        ) -> anyhow::Result<()> {
+            Ok(())
+        }
+
+        async fn recall(
+            &self,
+            _query: &str,
+            _limit: usize,
+        ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
+            Ok(Vec::new())
+        }
+
+        async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
+            Ok(None)
+        }
+
+        async fn list(
+            &self,
+            _category: Option<&crate::memory::MemoryCategory>,
+        ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
+            Ok(Vec::new())
+        }
+
+        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
+            Ok(false)
+        }
+
+        async fn count(&self) -> anyhow::Result<usize> {
+            Ok(0)
+        }
+
+        async fn health_check(&self) -> bool {
+            true
+        }
+    }
+
+    #[tokio::test]
+    async fn message_dispatch_processes_messages_in_parallel() {
+        let channel_impl = Arc::new(RecordingChannel::default());
+        let channel: Arc<dyn Channel> = channel_impl.clone();
+
+        let mut channels_by_name = HashMap::new();
+        channels_by_name.insert(channel.name().to_string(), channel);
+
+        let runtime_ctx = Arc::new(ChannelRuntimeContext {
+            channels_by_name: Arc::new(channels_by_name),
+            provider: Arc::new(SlowProvider {
+                delay: Duration::from_millis(250),
+            }),
+            memory: Arc::new(NoopMemory),
+            system_prompt: Arc::new("test-system-prompt".to_string()),
+            model: Arc::new("test-model".to_string()),
+            temperature: 0.0,
+            auto_save_memory: false,
+        });
+
+        let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
+        tx.send(traits::ChannelMessage {
+            id: "1".to_string(),
+            sender: "alice".to_string(),
+            content: "hello".to_string(),
+            channel: "test-channel".to_string(),
+            timestamp: 1,
+        })
+        .await
+        .unwrap();
+        tx.send(traits::ChannelMessage {
+            id: "2".to_string(),
+            sender: "bob".to_string(),
+            content: "world".to_string(),
+            channel: "test-channel".to_string(),
+            timestamp: 2,
+        })
+        .await
+        .unwrap();
+        drop(tx);
+
+        let started = Instant::now();
+        run_message_dispatch_loop(rx, runtime_ctx, 2).await;
+        let elapsed = started.elapsed();
+
+        assert!(
+            elapsed < Duration::from_millis(430),
+            "expected parallel dispatch (<430ms), got {:?}",
+            elapsed
+        );
+
+        let sent_messages = channel_impl.sent_messages.lock().await;
+        assert_eq!(sent_messages.len(), 2);
+    }
+
     #[test]
     fn prompt_contains_all_sections() {
         let ws = make_workspace();