fix: Keep typing status on telegram while message is being processed
# Conflicts: # src/channels/mod.rs
This commit is contained in:
parent
1bfd50bce9
commit
12c5473083
3 changed files with 103 additions and 12 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -5226,6 +5226,7 @@ dependencies = [
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tokio-serial",
|
"tokio-serial",
|
||||||
"tokio-tungstenite 0.24.0",
|
"tokio-tungstenite 0.24.0",
|
||||||
|
"tokio-util",
|
||||||
"toml 1.0.2+spec-1.1.0",
|
"toml 1.0.2+spec-1.1.0",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ clap = { version = "4.5", features = ["derive"] }
|
||||||
|
|
||||||
# Async runtime - feature-optimized for size
|
# Async runtime - feature-optimized for size
|
||||||
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] }
|
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] }
|
||||||
|
tokio-util = { version = "0.7", default-features = false }
|
||||||
|
|
||||||
# HTTP client - minimal features
|
# HTTP client - minimal features
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream"] }
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ use std::path::PathBuf;
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
/// Per-sender conversation history for channel messages.
|
/// Per-sender conversation history for channel messages.
|
||||||
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
|
type ConversationHistoryMap = Arc<Mutex<HashMap<String, Vec<ChatMessage>>>>;
|
||||||
|
|
@ -64,6 +65,7 @@ const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300;
|
||||||
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4;
|
||||||
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8;
|
||||||
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64;
|
||||||
|
const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct ChannelRuntimeContext {
|
struct ChannelRuntimeContext {
|
||||||
|
|
@ -177,6 +179,36 @@ fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn spawn_scoped_typing_task(
|
||||||
|
channel: Arc<dyn Channel>,
|
||||||
|
recipient: String,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
let stop_signal = cancellation_token;
|
||||||
|
let refresh_interval = Duration::from_secs(CHANNEL_TYPING_REFRESH_INTERVAL_SECS);
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(refresh_interval);
|
||||||
|
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
() = stop_signal.cancelled() => break,
|
||||||
|
_ = interval.tick() => {
|
||||||
|
if let Err(e) = channel.start_typing(&recipient).await {
|
||||||
|
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = channel.stop_typing(&recipient).await {
|
||||||
|
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
handle
|
||||||
|
}
|
||||||
|
|
||||||
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
|
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
|
||||||
println!(
|
println!(
|
||||||
" 💬 [{}] from {}: {}",
|
" 💬 [{}] from {}: {}",
|
||||||
|
|
@ -209,12 +241,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
|
|
||||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||||
|
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
|
||||||
if let Err(e) = channel.start_typing(&msg.reply_target).await {
|
|
||||||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
println!(" ⏳ Processing message...");
|
println!(" ⏳ Processing message...");
|
||||||
let started_at = Instant::now();
|
let started_at = Instant::now();
|
||||||
|
|
||||||
|
|
@ -294,6 +320,20 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let typing_cancellation = target_channel.as_ref().map(|_| {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
let guard = token.clone().drop_guard();
|
||||||
|
(token, guard)
|
||||||
|
});
|
||||||
|
let _typing_task = match (target_channel.as_ref(), typing_cancellation.as_ref()) {
|
||||||
|
(Some(channel), Some((token, _guard))) => Some(spawn_scoped_typing_task(
|
||||||
|
Arc::clone(channel),
|
||||||
|
msg.reply_target.clone(),
|
||||||
|
token.clone(),
|
||||||
|
)),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
let llm_result = tokio::time::timeout(
|
let llm_result = tokio::time::timeout(
|
||||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||||
run_tool_call_loop(
|
run_tool_call_loop(
|
||||||
|
|
@ -318,12 +358,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
let _ = handle.await;
|
let _ = handle.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
|
||||||
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
|
|
||||||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match llm_result {
|
match llm_result {
|
||||||
Ok(Ok(response)) => {
|
Ok(Ok(response)) => {
|
||||||
// Save user + assistant turn to per-sender history
|
// Save user + assistant turn to per-sender history
|
||||||
|
|
@ -1458,6 +1492,8 @@ mod tests {
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct RecordingChannel {
|
struct RecordingChannel {
|
||||||
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
sent_messages: tokio::sync::Mutex<Vec<String>>,
|
||||||
|
start_typing_calls: AtomicUsize,
|
||||||
|
stop_typing_calls: AtomicUsize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|
@ -1480,6 +1516,16 @@ mod tests {
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||||
|
self.start_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> {
|
||||||
|
self.stop_typing_calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SlowProvider {
|
struct SlowProvider {
|
||||||
|
|
@ -1851,6 +1897,49 @@ mod tests {
|
||||||
assert_eq!(sent_messages.len(), 2);
|
assert_eq!(sent_messages.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_channel_message_cancels_scoped_typing_task() {
|
||||||
|
let channel_impl = Arc::new(RecordingChannel::default());
|
||||||
|
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||||
|
|
||||||
|
let mut channels_by_name = HashMap::new();
|
||||||
|
channels_by_name.insert(channel.name().to_string(), channel);
|
||||||
|
|
||||||
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
|
channels_by_name: Arc::new(channels_by_name),
|
||||||
|
provider: Arc::new(SlowProvider {
|
||||||
|
delay: Duration::from_millis(20),
|
||||||
|
}),
|
||||||
|
memory: Arc::new(NoopMemory),
|
||||||
|
tools_registry: Arc::new(vec![]),
|
||||||
|
observer: Arc::new(NoopObserver),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||||
|
model: Arc::new("test-model".to_string()),
|
||||||
|
temperature: 0.0,
|
||||||
|
auto_save_memory: false,
|
||||||
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
|
});
|
||||||
|
|
||||||
|
process_channel_message(
|
||||||
|
runtime_ctx,
|
||||||
|
traits::ChannelMessage {
|
||||||
|
id: "typing-msg".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-typing".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
channel: "test-channel".to_string(),
|
||||||
|
timestamp: 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let starts = channel_impl.start_typing_calls.load(Ordering::SeqCst);
|
||||||
|
let stops = channel_impl.stop_typing_calls.load(Ordering::SeqCst);
|
||||||
|
assert_eq!(starts, 1, "start_typing should be called once");
|
||||||
|
assert_eq!(stops, 1, "stop_typing should be called once");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn prompt_contains_all_sections() {
|
fn prompt_contains_all_sections() {
|
||||||
let ws = make_workspace();
|
let ws = make_workspace();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue