parent
3bdabdc7ec
commit
c481f5298a
1 changed files with 330 additions and 96 deletions
|
|
@ -26,6 +26,7 @@ use crate::memory::{self, Memory};
|
|||
use crate::providers::{self, Provider};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
|
@ -36,6 +37,20 @@ const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
|||
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
|
||||
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
|
||||
const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 90;
|
||||
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
||||
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
||||
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChannelRuntimeContext {
|
||||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
provider: Arc<dyn Provider>,
|
||||
memory: Arc<dyn Memory>,
|
||||
system_prompt: Arc<String>,
|
||||
model: Arc<String>,
|
||||
temperature: f64,
|
||||
auto_save_memory: bool,
|
||||
}
|
||||
|
||||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||
|
|
@ -97,6 +112,151 @@ fn spawn_supervised_listener(
|
|||
})
|
||||
}
|
||||
|
||||
fn compute_max_in_flight_messages(channel_count: usize) -> usize {
|
||||
channel_count
|
||||
.saturating_mul(CHANNEL_PARALLELISM_PER_CHANNEL)
|
||||
.clamp(
|
||||
CHANNEL_MIN_IN_FLIGHT_MESSAGES,
|
||||
CHANNEL_MAX_IN_FLIGHT_MESSAGES,
|
||||
)
|
||||
}
|
||||
|
||||
fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
|
||||
if let Err(error) = result {
|
||||
tracing::error!("Channel message worker crashed: {error}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
|
||||
println!(
|
||||
" 💬 [{}] from {}: {}",
|
||||
msg.channel,
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 80)
|
||||
);
|
||||
|
||||
let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await;
|
||||
|
||||
if ctx.auto_save_memory {
|
||||
let autosave_key = conversation_memory_key(&msg);
|
||||
let _ = ctx
|
||||
.memory
|
||||
.store(
|
||||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let enriched_message = if memory_context.is_empty() {
|
||||
msg.content.clone()
|
||||
} else {
|
||||
format!("{memory_context}{}", msg.content)
|
||||
};
|
||||
|
||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.start_typing(&msg.sender).await {
|
||||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
||||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
ctx.provider.chat_with_system(
|
||||
Some(ctx.system_prompt.as_str()),
|
||||
&enriched_message,
|
||||
ctx.model.as_str(),
|
||||
ctx.temperature,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.stop_typing(&msg.sender).await {
|
||||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
||||
match llm_result {
|
||||
Ok(Ok(response)) => {
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
truncate_with_ellipsis(&response, 80)
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.send(&response, &msg.sender).await {
|
||||
eprintln!(" ❌ Failed to reply on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
eprintln!(
|
||||
" ❌ LLM error after {}ms: {e}",
|
||||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let timeout_msg = format!(
|
||||
"LLM response timed out after {}s",
|
||||
CHANNEL_MESSAGE_TIMEOUT_SECS
|
||||
);
|
||||
eprintln!(
|
||||
" ❌ {} (elapsed: {}ms)",
|
||||
timeout_msg,
|
||||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
let _ = channel
|
||||
.send(
|
||||
"⚠️ Request timed out while waiting for the model. Please try again.",
|
||||
&msg.sender,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_message_dispatch_loop(
|
||||
mut rx: tokio::sync::mpsc::Receiver<traits::ChannelMessage>,
|
||||
ctx: Arc<ChannelRuntimeContext>,
|
||||
max_in_flight_messages: usize,
|
||||
) {
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
||||
let mut workers = tokio::task::JoinSet::new();
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||||
Ok(permit) => permit,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
let worker_ctx = Arc::clone(&ctx);
|
||||
workers.spawn(async move {
|
||||
let _permit = permit;
|
||||
process_channel_message(worker_ctx, msg).await;
|
||||
});
|
||||
|
||||
while let Some(result) = workers.try_join_next() {
|
||||
log_worker_join_result(result);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(result) = workers.join_next().await {
|
||||
log_worker_join_result(result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Load OpenClaw format bootstrap files into the prompt.
|
||||
fn load_openclaw_bootstrap_files(prompt: &mut String, workspace_dir: &std::path::Path) {
|
||||
prompt
|
||||
|
|
@ -680,7 +840,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
|
||||
|
||||
// Single message bus — all channels send messages here
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||||
|
||||
// Spawn a listener for each channel
|
||||
let mut handles = Vec::new();
|
||||
|
|
@ -694,104 +854,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
}
|
||||
drop(tx); // Drop our copy so rx closes when all channels stop
|
||||
|
||||
// Process incoming messages — call the LLM and reply
|
||||
while let Some(msg) = rx.recv().await {
|
||||
println!(
|
||||
" 💬 [{}] from {}: {}",
|
||||
msg.channel,
|
||||
msg.sender,
|
||||
truncate_with_ellipsis(&msg.content, 80)
|
||||
let channels_by_name = Arc::new(
|
||||
channels
|
||||
.iter()
|
||||
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
|
||||
.collect::<HashMap<_, _>>(),
|
||||
);
|
||||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||||
|
||||
let memory_context = build_memory_context(mem.as_ref(), &msg.content).await;
|
||||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||||
|
||||
// Auto-save to memory
|
||||
if config.memory.auto_save {
|
||||
let autosave_key = conversation_memory_key(&msg);
|
||||
let _ = mem
|
||||
.store(
|
||||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name,
|
||||
provider: Arc::clone(&provider),
|
||||
memory: Arc::clone(&mem),
|
||||
system_prompt: Arc::new(system_prompt),
|
||||
model: Arc::new(model.clone()),
|
||||
temperature,
|
||||
auto_save_memory: config.memory.auto_save,
|
||||
});
|
||||
|
||||
let enriched_message = if memory_context.is_empty() {
|
||||
msg.content.clone()
|
||||
} else {
|
||||
format!("{memory_context}{}", msg.content)
|
||||
};
|
||||
|
||||
let target_channel = channels.iter().find(|ch| ch.name() == msg.channel);
|
||||
|
||||
// Show typing indicator while processing
|
||||
if let Some(ch) = target_channel {
|
||||
if let Err(e) = ch.start_typing(&msg.sender).await {
|
||||
tracing::debug!("Failed to start typing on {}: {e}", ch.name());
|
||||
}
|
||||
}
|
||||
|
||||
// Call the LLM with system prompt (identity + soul + tools)
|
||||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
provider.chat_with_system(Some(&system_prompt), &enriched_message, &model, temperature),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Stop typing before sending the response
|
||||
if let Some(ch) = target_channel {
|
||||
if let Err(e) = ch.stop_typing(&msg.sender).await {
|
||||
tracing::debug!("Failed to stop typing on {}: {e}", ch.name());
|
||||
}
|
||||
}
|
||||
|
||||
match llm_result {
|
||||
Ok(Ok(response)) => {
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
truncate_with_ellipsis(&response, 80)
|
||||
);
|
||||
if let Some(ch) = target_channel {
|
||||
if let Err(e) = ch.send(&response, &msg.sender).await {
|
||||
eprintln!(" ❌ Failed to reply on {}: {e}", ch.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
eprintln!(
|
||||
" ❌ LLM error after {}ms: {e}",
|
||||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(ch) = target_channel {
|
||||
let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let timeout_msg = format!(
|
||||
"LLM response timed out after {}s",
|
||||
CHANNEL_MESSAGE_TIMEOUT_SECS
|
||||
);
|
||||
eprintln!(
|
||||
" ❌ {} (elapsed: {}ms)",
|
||||
timeout_msg,
|
||||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(ch) = target_channel {
|
||||
let _ = ch
|
||||
.send(
|
||||
"⚠️ Request timed out while waiting for the model. Please try again.",
|
||||
&msg.sender,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||
|
||||
// Wait for all channel tasks
|
||||
for h in handles {
|
||||
|
|
@ -805,6 +888,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::providers::Provider;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
|
|
@ -830,6 +915,155 @@ mod tests {
|
|||
tmp
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RecordingChannel {
|
||||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Channel for RecordingChannel {
|
||||
fn name(&self) -> &str {
|
||||
"test-channel"
|
||||
}
|
||||
|
||||
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
||||
self.sent_messages
|
||||
.lock()
|
||||
.await
|
||||
.push(format!("{recipient}:{message}"));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct SlowProvider {
|
||||
delay: Duration,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for SlowProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
tokio::time::sleep(self.delay).await;
|
||||
Ok(format!("echo: {message}"))
|
||||
}
|
||||
}
|
||||
|
||||
struct NoopMemory;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Memory for NoopMemory {
|
||||
fn name(&self) -> &str {
|
||||
"noop"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: crate::memory::MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&crate::memory::MemoryCategory>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_dispatch_processes_messages_in_parallel() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(SlowProvider {
|
||||
delay: Duration::from_millis(250),
|
||||
}),
|
||||
memory: Arc::new(NoopMemory),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
content: "hello".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "2".to_string(),
|
||||
sender: "bob".to_string(),
|
||||
content: "world".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 2,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
drop(tx);
|
||||
|
||||
let started = Instant::now();
|
||||
run_message_dispatch_loop(rx, runtime_ctx, 2).await;
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed < Duration::from_millis(430),
|
||||
"expected parallel dispatch (<430ms), got {:?}",
|
||||
elapsed
|
||||
);
|
||||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_contains_all_sections() {
|
||||
let ws = make_workspace();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue