fix: Keep typing status on telegram while message is being processed

# Conflicts:
#	src/channels/mod.rs
This commit is contained in:
Jayson Reis 2026-02-18 09:10:37 +00:00 committed by Chummy
parent 1bfd50bce9
commit 12c5473083
3 changed files with 103 additions and 12 deletions

1
Cargo.lock generated
View file

@ -5226,6 +5226,7 @@ dependencies = [
"tokio-rustls",
"tokio-serial",
"tokio-tungstenite 0.24.0",
"tokio-util",
"toml 1.0.2+spec-1.1.0",
"tower",
"tower-http",

View file

@ -20,6 +20,7 @@ clap = { version = "4.5", features = ["derive"] }
# Async runtime - feature-optimized for size
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] }
tokio-util = { version = "0.7", default-features = false }
# HTTP client - minimal features
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream"] }

View file

@ -47,6 +47,7 @@ use std::path::PathBuf;
use std::process::Command;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
/// Per-sender conversation history for channel messages.
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
@ -64,6 +65,7 @@ const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300;
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
#[derive(Clone)]
struct ChannelRuntimeContext {
@ -177,6 +179,36 @@ fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
}
}
fn spawn_scoped_typing_task(
channel: Arc<dyn Channel>,
recipient: String,
cancellation_token: CancellationToken,
) -> tokio::task::JoinHandle<()> {
let stop_signal = cancellation_token;
let refresh_interval = Duration::from_secs(CHANNEL_TYPING_REFRESH_INTERVAL_SECS);
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(refresh_interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
() = stop_signal.cancelled() => break,
_ = interval.tick() => {
if let Err(e) = channel.start_typing(&recipient).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
}
}
}
}
if let Err(e) = channel.stop_typing(&recipient).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
}
});
handle
}
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
println!(
" 💬 [{}] from {}: {}",
@ -209,12 +241,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.start_typing(&msg.reply_target).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
}
}
println!(" ⏳ Processing message...");
let started_at = Instant::now();
@ -294,6 +320,20 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
None
};
let typing_cancellation = target_channel.as_ref().map(|_| {
let token = CancellationToken::new();
let guard = token.clone().drop_guard();
(token, guard)
});
let _typing_task = match (target_channel.as_ref(), typing_cancellation.as_ref()) {
(Some(channel), Some((token, _guard))) => Some(spawn_scoped_typing_task(
Arc::clone(channel),
msg.reply_target.clone(),
token.clone(),
)),
_ => None,
};
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
run_tool_call_loop(
@ -318,12 +358,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let _ = handle.await;
}
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
}
}
match llm_result {
Ok(Ok(response)) => {
// Save user + assistant turn to per-sender history
@ -1458,6 +1492,8 @@ mod tests {
#[derive(Default)]
struct RecordingChannel {
sent_messages: tokio::sync::Mutex<Vec<String>>,
start_typing_calls: AtomicUsize,
stop_typing_calls: AtomicUsize,
}
#[async_trait::async_trait]
@ -1480,6 +1516,16 @@ mod tests {
) -> anyhow::Result<()> {
Ok(())
}
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
self.start_typing_calls.fetch_add(1, Ordering::SeqCst);
Ok(())
}
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
self.stop_typing_calls.fetch_add(1, Ordering::SeqCst);
Ok(())
}
}
struct SlowProvider {
@ -1851,6 +1897,49 @@ mod tests {
assert_eq!(sent_messages.len(), 2);
}
#[tokio::test]
async fn process_channel_message_cancels_scoped_typing_task() {
let channel_impl = Arc::new(RecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(SlowProvider {
delay: Duration::from_millis(20),
}),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("test-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 10,
min_relevance_score: 0.0,
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "typing-msg".to_string(),
sender: "alice".to_string(),
reply_target: "chat-typing".to_string(),
content: "hello".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
},
)
.await;
let starts = channel_impl.start_typing_calls.load(Ordering::SeqCst);
let stops = channel_impl.stop_typing_calls.load(Ordering::SeqCst);
assert_eq!(starts, 1, "start_typing should be called once");
assert_eq!(stops, 1, "stop_typing should be called once");
}
#[test]
fn prompt_contains_all_sections() {
let ws = make_workspace();