From 12c54730832f0ce69d14e2d7b86e8387cdbac7c7 Mon Sep 17 00:00:00 2001 From: Jayson Reis Date: Wed, 18 Feb 2026 09:10:37 +0000 Subject: [PATCH] fix: Keep typing status on telegram while message is being processed # Conflicts: # src/channels/mod.rs --- Cargo.lock | 1 + Cargo.toml | 1 + src/channels/mod.rs | 113 +++++++++++++++++++++++++++++++++++++++----- 3 files changed, 103 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fce1823..0a7b284 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5226,6 +5226,7 @@ dependencies = [ "tokio-rustls", "tokio-serial", "tokio-tungstenite 0.24.0", + "tokio-util", "toml 1.0.2+spec-1.1.0", "tower", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index fc67ff4..6c22173 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ clap = { version = "4.5", features = ["derive"] } # Async runtime - feature-optimized for size tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] } +tokio-util = { version = "0.7", default-features = false } # HTTP client - minimal features reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream"] } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index b0fba77..43b3f05 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -47,6 +47,7 @@ use std::path::PathBuf; use std::process::Command; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; +use tokio_util::sync::CancellationToken; /// Per-sender conversation history for channel messages. type ConversationHistoryMap = Arc>>>; @@ -64,6 +65,7 @@ const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300; const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4; const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; +const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4; #[derive(Clone)] struct ChannelRuntimeContext { @@ -177,6 +179,36 @@ fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) { } } +fn spawn_scoped_typing_task( + channel: Arc, + recipient: String, + cancellation_token: CancellationToken, +) -> tokio::task::JoinHandle<()> { + let stop_signal = cancellation_token; + let refresh_interval = Duration::from_secs(CHANNEL_TYPING_REFRESH_INTERVAL_SECS); + let handle = tokio::spawn(async move { + let mut interval = tokio::time::interval(refresh_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + tokio::select! { + () = stop_signal.cancelled() => break, + _ = interval.tick() => { + if let Err(e) = channel.start_typing(&recipient).await { + tracing::debug!("Failed to start typing on {}: {e}", channel.name()); + } + } + } + } + + if let Err(e) = channel.stop_typing(&recipient).await { + tracing::debug!("Failed to stop typing on {}: {e}", channel.name()); + } + }); + + handle +} + async fn process_channel_message(ctx: Arc, msg: traits::ChannelMessage) { println!( " 💬 [{}] from {}: {}", @@ -209,12 +241,6 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); - if let Some(channel) = target_channel.as_ref() { - if let Err(e) = channel.start_typing(&msg.reply_target).await { - tracing::debug!("Failed to start typing on {}: {e}", channel.name()); - } - } - println!(" ⏳ Processing message..."); let started_at = Instant::now(); @@ -294,6 +320,20 @@ async fn process_channel_message(ctx: Arc, msg: traits::C None }; + let typing_cancellation = target_channel.as_ref().map(|_| { + let token = CancellationToken::new(); + let guard = token.clone().drop_guard(); + (token, guard) + }); + let _typing_task = match (target_channel.as_ref(), typing_cancellation.as_ref()) { + (Some(channel), Some((token, _guard))) => Some(spawn_scoped_typing_task( + Arc::clone(channel), + msg.reply_target.clone(), + token.clone(), + )), + _ => None, + }; + let llm_result = tokio::time::timeout( Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), run_tool_call_loop( @@ -318,12 +358,6 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let _ = handle.await; } - if let Some(channel) = target_channel.as_ref() { - if let Err(e) = channel.stop_typing(&msg.reply_target).await { - tracing::debug!("Failed to stop typing on {}: {e}", channel.name()); - } - } - match llm_result { Ok(Ok(response)) => { // Save user + assistant turn to per-sender history @@ -1458,6 +1492,8 @@ mod tests { #[derive(Default)] struct RecordingChannel { sent_messages: tokio::sync::Mutex>, + start_typing_calls: AtomicUsize, + stop_typing_calls: AtomicUsize, } #[async_trait::async_trait] @@ -1480,6 +1516,16 @@ mod tests { ) -> anyhow::Result<()> { Ok(()) } + + async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> { + self.start_typing_calls.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + self.stop_typing_calls.fetch_add(1, Ordering::SeqCst); + Ok(()) + } } struct SlowProvider { @@ -1851,6 +1897,49 @@ mod tests { assert_eq!(sent_messages.len(), 2); } + #[tokio::test] + async fn process_channel_message_cancels_scoped_typing_task() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(20), + }), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "typing-msg".to_string(), + sender: "alice".to_string(), + reply_target: "chat-typing".to_string(), + content: "hello".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + }, + ) + .await; + + let starts = channel_impl.start_typing_calls.load(Ordering::SeqCst); + let stops = channel_impl.stop_typing_calls.load(Ordering::SeqCst); + assert_eq!(starts, 1, "start_typing should be called once"); + assert_eq!(stops, 1, "stop_typing should be called once"); + } + #[test] fn prompt_contains_all_sections() { let ws = make_workspace();