fix(channels): interrupt in-flight telegram requests on newer sender messages
This commit is contained in:
parent
d9a94fc763
commit
ef82c7dbcd
17 changed files with 669 additions and 115 deletions
|
|
@ -56,6 +56,7 @@ use std::collections::HashMap;
|
|||
use std::fmt::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
|
@ -141,9 +142,43 @@ struct ChannelRuntimeContext {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||
workspace_dir: Arc<PathBuf>,
|
||||
message_timeout_secs: u64,
|
||||
interrupt_on_new_message: bool,
|
||||
multimodal: crate::config::MultimodalConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct InFlightSenderTaskState {
|
||||
task_id: u64,
|
||||
cancellation: CancellationToken,
|
||||
completion: Arc<InFlightTaskCompletion>,
|
||||
}
|
||||
|
||||
struct InFlightTaskCompletion {
|
||||
done: AtomicBool,
|
||||
notify: tokio::sync::Notify,
|
||||
}
|
||||
|
||||
impl InFlightTaskCompletion {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
done: AtomicBool::new(false),
|
||||
notify: tokio::sync::Notify::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn mark_done(&self) {
|
||||
self.done.store(true, Ordering::Release);
|
||||
self.notify.notify_waiters();
|
||||
}
|
||||
|
||||
async fn wait(&self) {
|
||||
if self.done.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
self.notify.notified().await;
|
||||
}
|
||||
}
|
||||
|
||||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||
}
|
||||
|
|
@ -152,6 +187,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
|||
format!("{}_{}", msg.channel, msg.sender)
|
||||
}
|
||||
|
||||
fn interruption_scope_key(msg: &traits::ChannelMessage) -> String {
|
||||
format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender)
|
||||
}
|
||||
|
||||
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||
match channel_name {
|
||||
"telegram" => Some(
|
||||
|
|
@ -292,6 +331,18 @@ fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool
|
|||
true
|
||||
}
|
||||
|
||||
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let turns = histories.entry(sender_key.to_string()).or_default();
|
||||
turns.push(turn);
|
||||
while turns.len() > MAX_CHANNEL_HISTORY {
|
||||
turns.remove(0);
|
||||
}
|
||||
}
|
||||
|
||||
fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||||
if memory::is_assistant_autosave_key(key) {
|
||||
return true;
|
||||
|
|
@ -657,7 +708,15 @@ fn spawn_scoped_typing_task(
|
|||
handle
|
||||
}
|
||||
|
||||
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
|
||||
async fn process_channel_message(
|
||||
ctx: Arc<ChannelRuntimeContext>,
|
||||
msg: traits::ChannelMessage,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
if cancellation_token.is_cancelled() {
|
||||
return;
|
||||
}
|
||||
|
||||
println!(
|
||||
" 💬 [{}] from {}: {}",
|
||||
msg.channel,
|
||||
|
|
@ -717,7 +776,13 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
// Build history from per-sender conversation cache
|
||||
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||||
append_sender_turn(
|
||||
ctx.as_ref(),
|
||||
&history_key,
|
||||
ChatMessage::user(&enriched_message),
|
||||
);
|
||||
|
||||
let mut prior_turns = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
|
|
@ -728,18 +793,15 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
|
||||
let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())];
|
||||
history.append(&mut prior_turns);
|
||||
history.push(ChatMessage::user(&enriched_message));
|
||||
|
||||
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
||||
history.push(ChatMessage::system(instructions));
|
||||
}
|
||||
|
||||
// Determine if this channel supports streaming draft updates
|
||||
let use_streaming = target_channel
|
||||
.as_ref()
|
||||
.map_or(false, |ch| ch.supports_draft_updates());
|
||||
.is_some_and(|ch| ch.supports_draft_updates());
|
||||
|
||||
// Set up streaming channel if supported
|
||||
let (delta_tx, delta_rx) = if use_streaming {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||
(Some(tx), Some(rx))
|
||||
|
|
@ -747,7 +809,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
(None, None)
|
||||
};
|
||||
|
||||
// Send initial draft message if streaming
|
||||
let draft_message_id = if use_streaming {
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
match channel
|
||||
|
|
@ -769,7 +830,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
None
|
||||
};
|
||||
|
||||
// Spawn a task to forward streaming deltas to draft updates
|
||||
let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = (
|
||||
delta_rx,
|
||||
draft_message_id.as_deref(),
|
||||
|
|
@ -804,27 +864,34 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
_ => None,
|
||||
};
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(ctx.message_timeout_secs),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
ctx.temperature,
|
||||
true,
|
||||
None,
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
delta_tx,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
enum LlmExecutionResult {
|
||||
Completed(Result<Result<String, anyhow::Error>, tokio::time::error::Elapsed>),
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
let llm_result = tokio::select! {
|
||||
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||
result = tokio::time::timeout(
|
||||
Duration::from_secs(ctx.message_timeout_secs),
|
||||
run_tool_call_loop(
|
||||
active_provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
route.provider.as_str(),
|
||||
route.model.as_str(),
|
||||
ctx.temperature,
|
||||
true,
|
||||
None,
|
||||
msg.channel.as_str(),
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx,
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
|
||||
// Wait for draft updater to finish
|
||||
if let Some(handle) = draft_updater {
|
||||
let _ = handle.await;
|
||||
}
|
||||
|
|
@ -837,21 +904,26 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
}
|
||||
|
||||
match llm_result {
|
||||
Ok(Ok(response)) => {
|
||||
// Save user + assistant turn to per-sender history
|
||||
LlmExecutionResult::Cancelled => {
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Cancelled in-flight channel request due to newer message"
|
||||
);
|
||||
if let (Some(channel), Some(draft_id)) =
|
||||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||||
{
|
||||
let mut histories = ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let turns = histories.entry(history_key).or_default();
|
||||
turns.push(ChatMessage::user(&enriched_message));
|
||||
turns.push(ChatMessage::assistant(&response));
|
||||
// Trim to MAX_CHANNEL_HISTORY (keep recent turns)
|
||||
while turns.len() > MAX_CHANNEL_HISTORY {
|
||||
turns.remove(0);
|
||||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
LlmExecutionResult::Completed(Ok(Ok(response))) => {
|
||||
append_sender_turn(
|
||||
ctx.as_ref(),
|
||||
&history_key,
|
||||
ChatMessage::assistant(&response),
|
||||
);
|
||||
println!(
|
||||
" 🤖 Reply ({}ms): {}",
|
||||
started_at.elapsed().as_millis(),
|
||||
|
|
@ -882,7 +954,24 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
}
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
LlmExecutionResult::Completed(Ok(Err(e))) => {
|
||||
if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled()
|
||||
{
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Cancelled in-flight channel request due to newer message"
|
||||
);
|
||||
if let (Some(channel), Some(draft_id)) =
|
||||
(target_channel.as_ref(), draft_message_id.as_deref())
|
||||
{
|
||||
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||||
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if is_context_window_overflow_error(&e) {
|
||||
let compacted = compact_sender_history(ctx.as_ref(), &history_key);
|
||||
let error_text = if compacted {
|
||||
|
|
@ -931,7 +1020,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
LlmExecutionResult::Completed(Err(_)) => {
|
||||
let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs);
|
||||
eprintln!(
|
||||
" ❌ {} (elapsed: {}ms)",
|
||||
|
|
@ -965,6 +1054,11 @@ async fn run_message_dispatch_loop(
|
|||
) {
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
||||
let mut workers = tokio::task::JoinSet::new();
|
||||
let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::<
|
||||
String,
|
||||
InFlightSenderTaskState,
|
||||
>::new()));
|
||||
let task_sequence = Arc::new(AtomicU64::new(1));
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||||
|
|
@ -973,9 +1067,54 @@ async fn run_message_dispatch_loop(
|
|||
};
|
||||
|
||||
let worker_ctx = Arc::clone(&ctx);
|
||||
let in_flight = Arc::clone(&in_flight_by_sender);
|
||||
let task_sequence = Arc::clone(&task_sequence);
|
||||
workers.spawn(async move {
|
||||
let _permit = permit;
|
||||
process_channel_message(worker_ctx, msg).await;
|
||||
let interrupt_enabled =
|
||||
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
|
||||
let sender_scope_key = interruption_scope_key(&msg);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if interrupt_enabled {
|
||||
let previous = {
|
||||
let mut active = in_flight.lock().await;
|
||||
active.insert(
|
||||
sender_scope_key.clone(),
|
||||
InFlightSenderTaskState {
|
||||
task_id,
|
||||
cancellation: cancellation_token.clone(),
|
||||
completion: Arc::clone(&completion),
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
if let Some(previous) = previous {
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Interrupting previous in-flight request for sender"
|
||||
);
|
||||
previous.cancellation.cancel();
|
||||
previous.completion.wait().await;
|
||||
}
|
||||
}
|
||||
|
||||
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||||
|
||||
if interrupt_enabled {
|
||||
let mut active = in_flight.lock().await;
|
||||
if active
|
||||
.get(&sender_scope_key)
|
||||
.is_some_and(|state| state.task_id == task_id)
|
||||
{
|
||||
active.remove(&sender_scope_key);
|
||||
}
|
||||
}
|
||||
|
||||
completion.mark_done();
|
||||
});
|
||||
|
||||
while let Some(result) = workers.try_join_next() {
|
||||
|
|
@ -2101,6 +2240,11 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
||||
let message_timeout_secs =
|
||||
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
||||
let interrupt_on_new_message = config
|
||||
.channels_config
|
||||
.telegram
|
||||
.as_ref()
|
||||
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name,
|
||||
|
|
@ -2124,6 +2268,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
provider_runtime_options,
|
||||
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
||||
message_timeout_secs,
|
||||
interrupt_on_new_message,
|
||||
multimodal: config.multimodal.clone(),
|
||||
});
|
||||
|
||||
|
|
@ -2245,6 +2390,7 @@ mod tests {
|
|||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
|
|
@ -2527,6 +2673,43 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
struct DelayedHistoryCaptureProvider {
|
||||
delay: Duration,
|
||||
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for DelayedHistoryCaptureProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("fallback".to_string())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let snapshot = messages
|
||||
.iter()
|
||||
.map(|m| (m.role.clone(), m.content.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
let call_index = {
|
||||
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||||
calls.push(snapshot);
|
||||
calls.len()
|
||||
};
|
||||
tokio::time::sleep(self.delay).await;
|
||||
Ok(format!("response-{call_index}"))
|
||||
}
|
||||
}
|
||||
|
||||
struct MockPriceTool;
|
||||
|
||||
#[derive(Default)]
|
||||
|
|
@ -2630,6 +2813,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2644,6 +2828,7 @@ mod tests {
|
|||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2685,6 +2870,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2699,6 +2885,7 @@ mod tests {
|
|||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2749,6 +2936,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2763,6 +2951,7 @@ mod tests {
|
|||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2834,6 +3023,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2848,6 +3038,7 @@ mod tests {
|
|||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2895,6 +3086,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2909,6 +3101,7 @@ mod tests {
|
|||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -2951,6 +3144,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -2965,6 +3159,7 @@ mod tests {
|
|||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -3058,6 +3253,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -3100,6 +3296,171 @@ mod tests {
|
|||
assert_eq!(sent_messages.len(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let provider_impl = Arc::new(DelayedHistoryCaptureProvider {
|
||||
delay: Duration::from_millis(250),
|
||||
calls: std::sync::Mutex::new(Vec::new()),
|
||||
});
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: provider_impl.clone(),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: true,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
let send_task = tokio::spawn(async move {
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "msg-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "forwarded content".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(Duration::from_millis(40)).await;
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "msg-2".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "summarize this".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||||
send_task.await.unwrap();
|
||||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 1);
|
||||
assert!(sent_messages[0].starts_with("chat-1:"));
|
||||
assert!(sent_messages[0].contains("response-2"));
|
||||
drop(sent_messages);
|
||||
|
||||
let calls = provider_impl
|
||||
.calls
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(calls.len(), 2);
|
||||
let second_call = &calls[1];
|
||||
assert!(second_call
|
||||
.iter()
|
||||
.any(|(role, content)| { role == "user" && content.contains("forwarded content") }));
|
||||
assert!(second_call
|
||||
.iter()
|
||||
.any(|(role, content)| { role == "user" && content.contains("summarize this") }));
|
||||
assert!(
|
||||
!second_call.iter().any(|(role, _)| role == "assistant"),
|
||||
"cancelled turn should not persist an assistant response"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(SlowProvider {
|
||||
delay: Duration::from_millis(180),
|
||||
}),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 10,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: true,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
let send_task = tokio::spawn(async move {
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "msg-a".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-1".to_string(),
|
||||
content: "first chat".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||
tx.send(traits::ChannelMessage {
|
||||
id: "msg-b".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-2".to_string(),
|
||||
content: "second chat".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||||
send_task.await.unwrap();
|
||||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 2);
|
||||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:")));
|
||||
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_cancels_scoped_typing_task() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
|
|
@ -3132,6 +3493,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -3146,6 +3508,7 @@ mod tests {
|
|||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -3579,6 +3942,7 @@ mod tests {
|
|||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
|
|
@ -3593,6 +3957,7 @@ mod tests {
|
|||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -3607,6 +3972,7 @@ mod tests {
|
|||
timestamp: 2,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue