fix(channels): interrupt in-flight telegram requests on newer sender messages
This commit is contained in:
parent
d9a94fc763
commit
ef82c7dbcd
17 changed files with 669 additions and 115 deletions
|
|
@ -138,8 +138,17 @@ Field names differ by channel:
|
||||||
[channels_config.telegram]
|
[channels_config.telegram]
|
||||||
bot_token = "123456:telegram-token"
|
bot_token = "123456:telegram-token"
|
||||||
allowed_users = ["*"]
|
allowed_users = ["*"]
|
||||||
|
stream_mode = "off" # optional: off | partial
|
||||||
|
draft_update_interval_ms = 1000 # optional: edit throttle for partial streaming
|
||||||
|
mention_only = false # optional: require @mention in groups
|
||||||
|
interrupt_on_new_message = false # optional: cancel in-flight same-sender same-chat request
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Telegram notes:
|
||||||
|
|
||||||
|
- `interrupt_on_new_message = true` preserves interrupted user turns in conversation history, then restarts generation on the newest message.
|
||||||
|
- Interruption scope is strict: same sender in the same chat. Messages from different chats are processed independently.
|
||||||
|
|
||||||
### 4.2 Discord
|
### 4.2 Discord
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
|
|
|
||||||
|
|
@ -188,6 +188,8 @@ Notes:
|
||||||
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
|
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
|
||||||
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
|
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
|
||||||
- When a timeout occurs, users receive: `⚠️ Request timed out while waiting for the model. Please try again.`
|
- When a timeout occurs, users receive: `⚠️ Request timed out while waiting for the model. Please try again.`
|
||||||
|
- Telegram-only interruption behavior is controlled with `channels_config.telegram.interrupt_on_new_message` (default `false`).
|
||||||
|
When enabled, a newer message from the same sender in the same chat cancels the in-flight request and preserves interrupted user context.
|
||||||
|
|
||||||
See detailed channel matrix and allowlist behavior in [channels-reference.md](channels-reference.md).
|
See detailed channel matrix and allowlist behavior in [channels-reference.md](channels-reference.md).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ use std::fmt::Write;
|
||||||
use std::io::Write as _;
|
use std::io::Write as _;
|
||||||
use std::sync::{Arc, LazyLock};
|
use std::sync::{Arc, LazyLock};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
|
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
|
||||||
|
|
@ -823,6 +824,21 @@ struct ParsedToolCall {
|
||||||
arguments: serde_json::Value,
|
arguments: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct ToolLoopCancelled;
|
||||||
|
|
||||||
|
impl std::fmt::Display for ToolLoopCancelled {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str("tool loop cancelled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ToolLoopCancelled {}
|
||||||
|
|
||||||
|
pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool {
|
||||||
|
err.chain().any(|source| source.is::<ToolLoopCancelled>())
|
||||||
|
}
|
||||||
|
|
||||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||||
/// execute tools, and loop until the LLM produces a final text response.
|
/// execute tools, and loop until the LLM produces a final text response.
|
||||||
/// When `silent` is true, suppresses stdout (for channel use).
|
/// When `silent` is true, suppresses stdout (for channel use).
|
||||||
|
|
@ -853,6 +869,7 @@ pub(crate) async fn agent_turn(
|
||||||
multimodal_config,
|
multimodal_config,
|
||||||
max_tool_iterations,
|
max_tool_iterations,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
@ -873,6 +890,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
channel_name: &str,
|
channel_name: &str,
|
||||||
multimodal_config: &crate::config::MultimodalConfig,
|
multimodal_config: &crate::config::MultimodalConfig,
|
||||||
max_tool_iterations: usize,
|
max_tool_iterations: usize,
|
||||||
|
cancellation_token: Option<CancellationToken>,
|
||||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let max_iterations = if max_tool_iterations == 0 {
|
let max_iterations = if max_tool_iterations == 0 {
|
||||||
|
|
@ -886,6 +904,13 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||||
|
|
||||||
for _iteration in 0..max_iterations {
|
for _iteration in 0..max_iterations {
|
||||||
|
if cancellation_token
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(CancellationToken::is_cancelled)
|
||||||
|
{
|
||||||
|
return Err(ToolLoopCancelled.into());
|
||||||
|
}
|
||||||
|
|
||||||
let image_marker_count = multimodal::count_image_markers(history);
|
let image_marker_count = multimodal::count_image_markers(history);
|
||||||
if image_marker_count > 0 && !provider.supports_vision() {
|
if image_marker_count > 0 && !provider.supports_vision() {
|
||||||
return Err(ProviderCapabilityError {
|
return Err(ProviderCapabilityError {
|
||||||
|
|
@ -917,18 +942,26 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
let chat_future = provider.chat(
|
||||||
match provider
|
|
||||||
.chat(
|
|
||||||
ChatRequest {
|
ChatRequest {
|
||||||
messages: &prepared_messages.messages,
|
messages: &prepared_messages.messages,
|
||||||
tools: request_tools,
|
tools: request_tools,
|
||||||
},
|
},
|
||||||
model,
|
model,
|
||||||
temperature,
|
temperature,
|
||||||
)
|
);
|
||||||
.await
|
|
||||||
{
|
let chat_result = if let Some(token) = cancellation_token.as_ref() {
|
||||||
|
tokio::select! {
|
||||||
|
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||||
|
result = chat_future => result,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chat_future.await
|
||||||
|
};
|
||||||
|
|
||||||
|
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||||
|
match chat_result {
|
||||||
Ok(resp) => {
|
Ok(resp) => {
|
||||||
observer.record_event(&ObserverEvent::LlmResponse {
|
observer.record_event(&ObserverEvent::LlmResponse {
|
||||||
provider: provider_name.to_string(),
|
provider: provider_name.to_string(),
|
||||||
|
|
@ -994,6 +1027,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
// STREAM_CHUNK_MIN_CHARS characters for progressive draft updates.
|
// STREAM_CHUNK_MIN_CHARS characters for progressive draft updates.
|
||||||
let mut chunk = String::new();
|
let mut chunk = String::new();
|
||||||
for word in display_text.split_inclusive(char::is_whitespace) {
|
for word in display_text.split_inclusive(char::is_whitespace) {
|
||||||
|
if cancellation_token
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(CancellationToken::is_cancelled)
|
||||||
|
{
|
||||||
|
return Err(ToolLoopCancelled.into());
|
||||||
|
}
|
||||||
chunk.push_str(word);
|
chunk.push_str(word);
|
||||||
if chunk.len() >= STREAM_CHUNK_MIN_CHARS
|
if chunk.len() >= STREAM_CHUNK_MIN_CHARS
|
||||||
&& tx.send(std::mem::take(&mut chunk)).await.is_err()
|
&& tx.send(std::mem::take(&mut chunk)).await.is_err()
|
||||||
|
|
@ -1056,7 +1095,17 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
});
|
});
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
|
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
|
||||||
match tool.execute(call.arguments.clone()).await {
|
let tool_future = tool.execute(call.arguments.clone());
|
||||||
|
let tool_result = if let Some(token) = cancellation_token.as_ref() {
|
||||||
|
tokio::select! {
|
||||||
|
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||||
|
result = tool_future => result,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tool_future.await
|
||||||
|
};
|
||||||
|
|
||||||
|
match tool_result {
|
||||||
Ok(r) => {
|
Ok(r) => {
|
||||||
observer.record_event(&ObserverEvent::ToolCall {
|
observer.record_event(&ObserverEvent::ToolCall {
|
||||||
tool: call.name.clone(),
|
tool: call.name.clone(),
|
||||||
|
|
@ -1435,6 +1484,7 @@ pub async fn run(
|
||||||
&config.multimodal,
|
&config.multimodal,
|
||||||
config.agent.max_tool_iterations,
|
config.agent.max_tool_iterations,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
final_output = response.clone();
|
final_output = response.clone();
|
||||||
|
|
@ -1553,6 +1603,7 @@ pub async fn run(
|
||||||
&config.multimodal,
|
&config.multimodal,
|
||||||
config.agent.max_tool_iterations,
|
config.agent.max_tool_iterations,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
@ -1900,6 +1951,7 @@ mod tests {
|
||||||
&crate::config::MultimodalConfig::default(),
|
&crate::config::MultimodalConfig::default(),
|
||||||
3,
|
3,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect_err("provider without vision support should fail");
|
.expect_err("provider without vision support should fail");
|
||||||
|
|
@ -1943,6 +1995,7 @@ mod tests {
|
||||||
&multimodal,
|
&multimodal,
|
||||||
3,
|
3,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect_err("oversized payload must fail");
|
.expect_err("oversized payload must fail");
|
||||||
|
|
@ -1980,6 +2033,7 @@ mod tests {
|
||||||
&crate::config::MultimodalConfig::default(),
|
&crate::config::MultimodalConfig::default(),
|
||||||
3,
|
3,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.expect("valid multimodal payload should pass");
|
.expect("valid multimodal payload should pass");
|
||||||
|
|
@ -2809,7 +2863,10 @@ browser_open/url>https://example.com"#;
|
||||||
fn parse_tool_calls_closing_tag_only_returns_text() {
|
fn parse_tool_calls_closing_tag_only_returns_text() {
|
||||||
let response = "Some text </tool_call> more text";
|
let response = "Some text </tool_call> more text";
|
||||||
let (text, calls) = parse_tool_calls(response);
|
let (text, calls) = parse_tool_calls(response);
|
||||||
assert!(calls.is_empty(), "closing tag only should not produce calls");
|
assert!(
|
||||||
|
calls.is_empty(),
|
||||||
|
"closing tag only should not produce calls"
|
||||||
|
);
|
||||||
assert!(
|
assert!(
|
||||||
!text.is_empty(),
|
!text.is_empty(),
|
||||||
"text around orphaned closing tag should be preserved"
|
"text around orphaned closing tag should be preserved"
|
||||||
|
|
@ -2858,7 +2915,11 @@ browser_open/url>https://example.com"#;
|
||||||
|
|
||||||
Let me check the result."#;
|
Let me check the result."#;
|
||||||
let (text, calls) = parse_tool_calls(response);
|
let (text, calls) = parse_tool_calls(response);
|
||||||
assert_eq!(calls.len(), 1, "should extract one tool call from mixed content");
|
assert_eq!(
|
||||||
|
calls.len(),
|
||||||
|
1,
|
||||||
|
"should extract one tool call from mixed content"
|
||||||
|
);
|
||||||
assert_eq!(calls[0].name, "shell");
|
assert_eq!(calls[0].name, "shell");
|
||||||
assert!(
|
assert!(
|
||||||
text.contains("help you"),
|
text.contains("help you"),
|
||||||
|
|
@ -2880,7 +2941,10 @@ Let me check the result."#;
|
||||||
fn scrub_credentials_no_sensitive_data() {
|
fn scrub_credentials_no_sensitive_data() {
|
||||||
let input = "normal text without any secrets";
|
let input = "normal text without any secrets";
|
||||||
let result = scrub_credentials(input);
|
let result = scrub_credentials(input);
|
||||||
assert_eq!(result, input, "non-sensitive text should pass through unchanged");
|
assert_eq!(
|
||||||
|
result, input,
|
||||||
|
"non-sensitive text should pass through unchanged"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -857,7 +857,10 @@ mod tests {
|
||||||
msg.push_str(&"x".repeat(1990));
|
msg.push_str(&"x".repeat(1990));
|
||||||
msg.push_str("\n```\nMore text after code block");
|
msg.push_str("\n```\nMore text after code block");
|
||||||
let parts = split_message_for_discord(&msg);
|
let parts = split_message_for_discord(&msg);
|
||||||
assert!(parts.len() >= 2, "code block spanning boundary should split");
|
assert!(
|
||||||
|
parts.len() >= 2,
|
||||||
|
"code block spanning boundary should split"
|
||||||
|
);
|
||||||
for part in &parts {
|
for part in &parts {
|
||||||
assert!(
|
assert!(
|
||||||
part.len() <= DISCORD_MAX_MESSAGE_LENGTH,
|
part.len() <= DISCORD_MAX_MESSAGE_LENGTH,
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ use std::collections::HashMap;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
@ -141,9 +142,43 @@ struct ChannelRuntimeContext {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions,
|
provider_runtime_options: providers::ProviderRuntimeOptions,
|
||||||
workspace_dir: Arc<PathBuf>,
|
workspace_dir: Arc<PathBuf>,
|
||||||
message_timeout_secs: u64,
|
message_timeout_secs: u64,
|
||||||
|
interrupt_on_new_message: bool,
|
||||||
multimodal: crate::config::MultimodalConfig,
|
multimodal: crate::config::MultimodalConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct InFlightSenderTaskState {
|
||||||
|
task_id: u64,
|
||||||
|
cancellation: CancellationToken,
|
||||||
|
completion: Arc<InFlightTaskCompletion>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct InFlightTaskCompletion {
|
||||||
|
done: AtomicBool,
|
||||||
|
notify: tokio::sync::Notify,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InFlightTaskCompletion {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
done: AtomicBool::new(false),
|
||||||
|
notify: tokio::sync::Notify::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_done(&self) {
|
||||||
|
self.done.store(true, Ordering::Release);
|
||||||
|
self.notify.notify_waiters();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait(&self) {
|
||||||
|
if self.done.load(Ordering::Acquire) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
self.notify.notified().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
@ -152,6 +187,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String {
|
||||||
format!("{}_{}", msg.channel, msg.sender)
|
format!("{}_{}", msg.channel, msg.sender)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn interruption_scope_key(msg: &traits::ChannelMessage) -> String {
|
||||||
|
format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender)
|
||||||
|
}
|
||||||
|
|
||||||
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||||
match channel_name {
|
match channel_name {
|
||||||
"telegram" => Some(
|
"telegram" => Some(
|
||||||
|
|
@ -292,6 +331,18 @@ fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) {
|
||||||
|
let mut histories = ctx
|
||||||
|
.conversation_histories
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner());
|
||||||
|
let turns = histories.entry(sender_key.to_string()).or_default();
|
||||||
|
turns.push(turn);
|
||||||
|
while turns.len() > MAX_CHANNEL_HISTORY {
|
||||||
|
turns.remove(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
fn should_skip_memory_context_entry(key: &str, content: &str) -> bool {
|
||||||
if memory::is_assistant_autosave_key(key) {
|
if memory::is_assistant_autosave_key(key) {
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -657,7 +708,15 @@ fn spawn_scoped_typing_task(
|
||||||
handle
|
handle
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::ChannelMessage) {
|
async fn process_channel_message(
|
||||||
|
ctx: Arc<ChannelRuntimeContext>,
|
||||||
|
msg: traits::ChannelMessage,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
|
) {
|
||||||
|
if cancellation_token.is_cancelled() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
" 💬 [{}] from {}: {}",
|
" 💬 [{}] from {}: {}",
|
||||||
msg.channel,
|
msg.channel,
|
||||||
|
|
@ -717,7 +776,13 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
println!(" ⏳ Processing message...");
|
println!(" ⏳ Processing message...");
|
||||||
let started_at = Instant::now();
|
let started_at = Instant::now();
|
||||||
|
|
||||||
// Build history from per-sender conversation cache
|
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||||||
|
append_sender_turn(
|
||||||
|
ctx.as_ref(),
|
||||||
|
&history_key,
|
||||||
|
ChatMessage::user(&enriched_message),
|
||||||
|
);
|
||||||
|
|
||||||
let mut prior_turns = ctx
|
let mut prior_turns = ctx
|
||||||
.conversation_histories
|
.conversation_histories
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -728,18 +793,15 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
|
|
||||||
let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())];
|
let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())];
|
||||||
history.append(&mut prior_turns);
|
history.append(&mut prior_turns);
|
||||||
history.push(ChatMessage::user(&enriched_message));
|
|
||||||
|
|
||||||
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
||||||
history.push(ChatMessage::system(instructions));
|
history.push(ChatMessage::system(instructions));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if this channel supports streaming draft updates
|
|
||||||
let use_streaming = target_channel
|
let use_streaming = target_channel
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map_or(false, |ch| ch.supports_draft_updates());
|
.is_some_and(|ch| ch.supports_draft_updates());
|
||||||
|
|
||||||
// Set up streaming channel if supported
|
|
||||||
let (delta_tx, delta_rx) = if use_streaming {
|
let (delta_tx, delta_rx) = if use_streaming {
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel::<String>(64);
|
let (tx, rx) = tokio::sync::mpsc::channel::<String>(64);
|
||||||
(Some(tx), Some(rx))
|
(Some(tx), Some(rx))
|
||||||
|
|
@ -747,7 +809,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
(None, None)
|
(None, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Send initial draft message if streaming
|
|
||||||
let draft_message_id = if use_streaming {
|
let draft_message_id = if use_streaming {
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
if let Some(channel) = target_channel.as_ref() {
|
||||||
match channel
|
match channel
|
||||||
|
|
@ -769,7 +830,6 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Spawn a task to forward streaming deltas to draft updates
|
|
||||||
let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = (
|
let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = (
|
||||||
delta_rx,
|
delta_rx,
|
||||||
draft_message_id.as_deref(),
|
draft_message_id.as_deref(),
|
||||||
|
|
@ -804,7 +864,14 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let llm_result = tokio::time::timeout(
|
enum LlmExecutionResult {
|
||||||
|
Completed(Result<Result<String, anyhow::Error>, tokio::time::error::Elapsed>),
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
let llm_result = tokio::select! {
|
||||||
|
() = cancellation_token.cancelled() => LlmExecutionResult::Cancelled,
|
||||||
|
result = tokio::time::timeout(
|
||||||
Duration::from_secs(ctx.message_timeout_secs),
|
Duration::from_secs(ctx.message_timeout_secs),
|
||||||
run_tool_call_loop(
|
run_tool_call_loop(
|
||||||
active_provider.as_ref(),
|
active_provider.as_ref(),
|
||||||
|
|
@ -819,12 +886,12 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
msg.channel.as_str(),
|
msg.channel.as_str(),
|
||||||
&ctx.multimodal,
|
&ctx.multimodal,
|
||||||
ctx.max_tool_iterations,
|
ctx.max_tool_iterations,
|
||||||
|
Some(cancellation_token.clone()),
|
||||||
delta_tx,
|
delta_tx,
|
||||||
),
|
),
|
||||||
)
|
) => LlmExecutionResult::Completed(result),
|
||||||
.await;
|
};
|
||||||
|
|
||||||
// Wait for draft updater to finish
|
|
||||||
if let Some(handle) = draft_updater {
|
if let Some(handle) = draft_updater {
|
||||||
let _ = handle.await;
|
let _ = handle.await;
|
||||||
}
|
}
|
||||||
|
|
@ -837,21 +904,26 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
}
|
}
|
||||||
|
|
||||||
match llm_result {
|
match llm_result {
|
||||||
Ok(Ok(response)) => {
|
LlmExecutionResult::Cancelled => {
|
||||||
// Save user + assistant turn to per-sender history
|
tracing::info!(
|
||||||
|
channel = %msg.channel,
|
||||||
|
sender = %msg.sender,
|
||||||
|
"Cancelled in-flight channel request due to newer message"
|
||||||
|
);
|
||||||
|
if let (Some(channel), Some(draft_id)) =
|
||||||
|
(target_channel.as_ref(), draft_message_id.as_deref())
|
||||||
{
|
{
|
||||||
let mut histories = ctx
|
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||||||
.conversation_histories
|
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||||||
.lock()
|
|
||||||
.unwrap_or_else(|e| e.into_inner());
|
|
||||||
let turns = histories.entry(history_key).or_default();
|
|
||||||
turns.push(ChatMessage::user(&enriched_message));
|
|
||||||
turns.push(ChatMessage::assistant(&response));
|
|
||||||
// Trim to MAX_CHANNEL_HISTORY (keep recent turns)
|
|
||||||
while turns.len() > MAX_CHANNEL_HISTORY {
|
|
||||||
turns.remove(0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
LlmExecutionResult::Completed(Ok(Ok(response))) => {
|
||||||
|
append_sender_turn(
|
||||||
|
ctx.as_ref(),
|
||||||
|
&history_key,
|
||||||
|
ChatMessage::assistant(&response),
|
||||||
|
);
|
||||||
println!(
|
println!(
|
||||||
" 🤖 Reply ({}ms): {}",
|
" 🤖 Reply ({}ms): {}",
|
||||||
started_at.elapsed().as_millis(),
|
started_at.elapsed().as_millis(),
|
||||||
|
|
@ -882,7 +954,24 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
LlmExecutionResult::Completed(Ok(Err(e))) => {
|
||||||
|
if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled()
|
||||||
|
{
|
||||||
|
tracing::info!(
|
||||||
|
channel = %msg.channel,
|
||||||
|
sender = %msg.sender,
|
||||||
|
"Cancelled in-flight channel request due to newer message"
|
||||||
|
);
|
||||||
|
if let (Some(channel), Some(draft_id)) =
|
||||||
|
(target_channel.as_ref(), draft_message_id.as_deref())
|
||||||
|
{
|
||||||
|
if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await {
|
||||||
|
tracing::debug!("Failed to cancel draft on {}: {err}", channel.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if is_context_window_overflow_error(&e) {
|
if is_context_window_overflow_error(&e) {
|
||||||
let compacted = compact_sender_history(ctx.as_ref(), &history_key);
|
let compacted = compact_sender_history(ctx.as_ref(), &history_key);
|
||||||
let error_text = if compacted {
|
let error_text = if compacted {
|
||||||
|
|
@ -931,7 +1020,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
LlmExecutionResult::Completed(Err(_)) => {
|
||||||
let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs);
|
let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs);
|
||||||
eprintln!(
|
eprintln!(
|
||||||
" ❌ {} (elapsed: {}ms)",
|
" ❌ {} (elapsed: {}ms)",
|
||||||
|
|
@ -965,6 +1054,11 @@ async fn run_message_dispatch_loop(
|
||||||
) {
|
) {
|
||||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages));
|
||||||
let mut workers = tokio::task::JoinSet::new();
|
let mut workers = tokio::task::JoinSet::new();
|
||||||
|
let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::<
|
||||||
|
String,
|
||||||
|
InFlightSenderTaskState,
|
||||||
|
>::new()));
|
||||||
|
let task_sequence = Arc::new(AtomicU64::new(1));
|
||||||
|
|
||||||
while let Some(msg) = rx.recv().await {
|
while let Some(msg) = rx.recv().await {
|
||||||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||||||
|
|
@ -973,9 +1067,54 @@ async fn run_message_dispatch_loop(
|
||||||
};
|
};
|
||||||
|
|
||||||
let worker_ctx = Arc::clone(&ctx);
|
let worker_ctx = Arc::clone(&ctx);
|
||||||
|
let in_flight = Arc::clone(&in_flight_by_sender);
|
||||||
|
let task_sequence = Arc::clone(&task_sequence);
|
||||||
workers.spawn(async move {
|
workers.spawn(async move {
|
||||||
let _permit = permit;
|
let _permit = permit;
|
||||||
process_channel_message(worker_ctx, msg).await;
|
let interrupt_enabled =
|
||||||
|
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
|
||||||
|
let sender_scope_key = interruption_scope_key(&msg);
|
||||||
|
let cancellation_token = CancellationToken::new();
|
||||||
|
let completion = Arc::new(InFlightTaskCompletion::new());
|
||||||
|
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
if interrupt_enabled {
|
||||||
|
let previous = {
|
||||||
|
let mut active = in_flight.lock().await;
|
||||||
|
active.insert(
|
||||||
|
sender_scope_key.clone(),
|
||||||
|
InFlightSenderTaskState {
|
||||||
|
task_id,
|
||||||
|
cancellation: cancellation_token.clone(),
|
||||||
|
completion: Arc::clone(&completion),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(previous) = previous {
|
||||||
|
tracing::info!(
|
||||||
|
channel = %msg.channel,
|
||||||
|
sender = %msg.sender,
|
||||||
|
"Interrupting previous in-flight request for sender"
|
||||||
|
);
|
||||||
|
previous.cancellation.cancel();
|
||||||
|
previous.completion.wait().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
process_channel_message(worker_ctx, msg, cancellation_token).await;
|
||||||
|
|
||||||
|
if interrupt_enabled {
|
||||||
|
let mut active = in_flight.lock().await;
|
||||||
|
if active
|
||||||
|
.get(&sender_scope_key)
|
||||||
|
.is_some_and(|state| state.task_id == task_id)
|
||||||
|
{
|
||||||
|
active.remove(&sender_scope_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
completion.mark_done();
|
||||||
});
|
});
|
||||||
|
|
||||||
while let Some(result) = workers.try_join_next() {
|
while let Some(result) = workers.try_join_next() {
|
||||||
|
|
@ -2101,6 +2240,11 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider));
|
||||||
let message_timeout_secs =
|
let message_timeout_secs =
|
||||||
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
||||||
|
let interrupt_on_new_message = config
|
||||||
|
.channels_config
|
||||||
|
.telegram
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||||||
|
|
||||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
channels_by_name,
|
channels_by_name,
|
||||||
|
|
@ -2124,6 +2268,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
provider_runtime_options,
|
provider_runtime_options,
|
||||||
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
workspace_dir: Arc::new(config.workspace_dir.clone()),
|
||||||
message_timeout_secs,
|
message_timeout_secs,
|
||||||
|
interrupt_on_new_message,
|
||||||
multimodal: config.multimodal.clone(),
|
multimodal: config.multimodal.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2245,6 +2390,7 @@ mod tests {
|
||||||
api_key: None,
|
api_key: None,
|
||||||
api_url: None,
|
api_url: None,
|
||||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
|
|
@ -2527,6 +2673,43 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DelayedHistoryCaptureProvider {
|
||||||
|
delay: Duration,
|
||||||
|
calls: std::sync::Mutex<Vec<Vec<(String, String)>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Provider for DelayedHistoryCaptureProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok("fallback".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let snapshot = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| (m.role.clone(), m.content.clone()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let call_index = {
|
||||||
|
let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner());
|
||||||
|
calls.push(snapshot);
|
||||||
|
calls.len()
|
||||||
|
};
|
||||||
|
tokio::time::sleep(self.delay).await;
|
||||||
|
Ok(format!("response-{call_index}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct MockPriceTool;
|
struct MockPriceTool;
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|
@ -2630,6 +2813,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2644,6 +2828,7 @@ mod tests {
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -2685,6 +2870,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2699,6 +2885,7 @@ mod tests {
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -2749,6 +2936,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2763,6 +2951,7 @@ mod tests {
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -2834,6 +3023,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2848,6 +3038,7 @@ mod tests {
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -2895,6 +3086,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2909,6 +3101,7 @@ mod tests {
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -2951,6 +3144,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -2965,6 +3159,7 @@ mod tests {
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -3058,6 +3253,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -3100,6 +3296,171 @@ mod tests {
|
||||||
assert_eq!(sent_messages.len(), 2);
|
assert_eq!(sent_messages.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() {
|
||||||
|
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||||
|
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||||
|
|
||||||
|
let mut channels_by_name = HashMap::new();
|
||||||
|
channels_by_name.insert(channel.name().to_string(), channel);
|
||||||
|
|
||||||
|
let provider_impl = Arc::new(DelayedHistoryCaptureProvider {
|
||||||
|
delay: Duration::from_millis(250),
|
||||||
|
calls: std::sync::Mutex::new(Vec::new()),
|
||||||
|
});
|
||||||
|
|
||||||
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
|
channels_by_name: Arc::new(channels_by_name),
|
||||||
|
provider: provider_impl.clone(),
|
||||||
|
default_provider: Arc::new("test-provider".to_string()),
|
||||||
|
memory: Arc::new(NoopMemory),
|
||||||
|
tools_registry: Arc::new(vec![]),
|
||||||
|
observer: Arc::new(NoopObserver),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||||
|
model: Arc::new("test-model".to_string()),
|
||||||
|
temperature: 0.0,
|
||||||
|
auto_save_memory: false,
|
||||||
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
|
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
api_key: None,
|
||||||
|
api_url: None,
|
||||||
|
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||||
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: true,
|
||||||
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||||
|
let send_task = tokio::spawn(async move {
|
||||||
|
tx.send(traits::ChannelMessage {
|
||||||
|
id: "msg-1".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-1".to_string(),
|
||||||
|
content: "forwarded content".to_string(),
|
||||||
|
channel: "telegram".to_string(),
|
||||||
|
timestamp: 1,
|
||||||
|
thread_ts: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(Duration::from_millis(40)).await;
|
||||||
|
tx.send(traits::ChannelMessage {
|
||||||
|
id: "msg-2".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-1".to_string(),
|
||||||
|
content: "summarize this".to_string(),
|
||||||
|
channel: "telegram".to_string(),
|
||||||
|
timestamp: 2,
|
||||||
|
thread_ts: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||||||
|
send_task.await.unwrap();
|
||||||
|
|
||||||
|
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||||
|
assert_eq!(sent_messages.len(), 1);
|
||||||
|
assert!(sent_messages[0].starts_with("chat-1:"));
|
||||||
|
assert!(sent_messages[0].contains("response-2"));
|
||||||
|
drop(sent_messages);
|
||||||
|
|
||||||
|
let calls = provider_impl
|
||||||
|
.calls
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner());
|
||||||
|
assert_eq!(calls.len(), 2);
|
||||||
|
let second_call = &calls[1];
|
||||||
|
assert!(second_call
|
||||||
|
.iter()
|
||||||
|
.any(|(role, content)| { role == "user" && content.contains("forwarded content") }));
|
||||||
|
assert!(second_call
|
||||||
|
.iter()
|
||||||
|
.any(|(role, content)| { role == "user" && content.contains("summarize this") }));
|
||||||
|
assert!(
|
||||||
|
!second_call.iter().any(|(role, _)| role == "assistant"),
|
||||||
|
"cancelled turn should not persist an assistant response"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() {
|
||||||
|
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||||
|
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||||
|
|
||||||
|
let mut channels_by_name = HashMap::new();
|
||||||
|
channels_by_name.insert(channel.name().to_string(), channel);
|
||||||
|
|
||||||
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
|
channels_by_name: Arc::new(channels_by_name),
|
||||||
|
provider: Arc::new(SlowProvider {
|
||||||
|
delay: Duration::from_millis(180),
|
||||||
|
}),
|
||||||
|
default_provider: Arc::new("test-provider".to_string()),
|
||||||
|
memory: Arc::new(NoopMemory),
|
||||||
|
tools_registry: Arc::new(vec![]),
|
||||||
|
observer: Arc::new(NoopObserver),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||||
|
model: Arc::new("test-model".to_string()),
|
||||||
|
temperature: 0.0,
|
||||||
|
auto_save_memory: false,
|
||||||
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
|
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
api_key: None,
|
||||||
|
api_url: None,
|
||||||
|
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||||
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: true,
|
||||||
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||||
|
let send_task = tokio::spawn(async move {
|
||||||
|
tx.send(traits::ChannelMessage {
|
||||||
|
id: "msg-a".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-1".to_string(),
|
||||||
|
content: "first chat".to_string(),
|
||||||
|
channel: "telegram".to_string(),
|
||||||
|
timestamp: 1,
|
||||||
|
thread_ts: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||||
|
tx.send(traits::ChannelMessage {
|
||||||
|
id: "msg-b".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-2".to_string(),
|
||||||
|
content: "second chat".to_string(),
|
||||||
|
channel: "telegram".to_string(),
|
||||||
|
timestamp: 2,
|
||||||
|
thread_ts: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
run_message_dispatch_loop(rx, runtime_ctx, 4).await;
|
||||||
|
send_task.await.unwrap();
|
||||||
|
|
||||||
|
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||||
|
assert_eq!(sent_messages.len(), 2);
|
||||||
|
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:")));
|
||||||
|
assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:")));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn process_channel_message_cancels_scoped_typing_task() {
|
async fn process_channel_message_cancels_scoped_typing_task() {
|
||||||
let channel_impl = Arc::new(RecordingChannel::default());
|
let channel_impl = Arc::new(RecordingChannel::default());
|
||||||
|
|
@ -3132,6 +3493,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -3146,6 +3508,7 @@ mod tests {
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -3579,6 +3942,7 @@ mod tests {
|
||||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
multimodal: crate::config::MultimodalConfig::default(),
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -3593,6 +3957,7 @@ mod tests {
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -3607,6 +3972,7 @@ mod tests {
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
thread_ts: None,
|
thread_ts: None,
|
||||||
},
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,10 +45,7 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
|
||||||
pos + 1
|
pos + 1
|
||||||
} else {
|
} else {
|
||||||
// Try space as fallback
|
// Try space as fallback
|
||||||
search_area
|
search_area.rfind(' ').unwrap_or(hard_split) + 1
|
||||||
.rfind(' ')
|
|
||||||
.unwrap_or(hard_split)
|
|
||||||
+ 1
|
|
||||||
}
|
}
|
||||||
} else if let Some(pos) = search_area.rfind(' ') {
|
} else if let Some(pos) = search_area.rfind(' ') {
|
||||||
pos + 1
|
pos + 1
|
||||||
|
|
@ -1632,6 +1629,37 @@ impl Channel for TelegramChannel {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn cancel_draft(&self, recipient: &str, message_id: &str) -> anyhow::Result<()> {
|
||||||
|
let (chat_id, _) = Self::parse_reply_target(recipient);
|
||||||
|
self.last_draft_edit.lock().remove(&chat_id);
|
||||||
|
|
||||||
|
let message_id = match message_id.parse::<i64>() {
|
||||||
|
Ok(id) => id,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!("Invalid Telegram draft message_id '{message_id}': {e}");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(self.api_url("deleteMessage"))
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
tracing::debug!("Telegram deleteMessage failed ({status}): {body}");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||||
// Strip tool_call tags before processing to prevent Markdown parsing failures
|
// Strip tool_call tags before processing to prevent Markdown parsing failures
|
||||||
let content = strip_tool_call_tags(&message.content);
|
let content = strip_tool_call_tags(&message.content);
|
||||||
|
|
@ -2844,7 +2872,10 @@ mod tests {
|
||||||
msg.push_str(&"x".repeat(4085));
|
msg.push_str(&"x".repeat(4085));
|
||||||
msg.push_str("\n```\nMore text after code block");
|
msg.push_str("\n```\nMore text after code block");
|
||||||
let parts = split_message_for_telegram(&msg);
|
let parts = split_message_for_telegram(&msg);
|
||||||
assert!(parts.len() >= 2, "code block spanning boundary should split");
|
assert!(
|
||||||
|
parts.len() >= 2,
|
||||||
|
"code block spanning boundary should split"
|
||||||
|
);
|
||||||
for part in &parts {
|
for part in &parts {
|
||||||
assert!(
|
assert!(
|
||||||
part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,11 @@ pub trait Channel: Send + Sync {
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Cancel and remove a previously sent draft message if the channel supports it.
|
||||||
|
async fn cancel_draft(&self, _recipient: &str, _message_id: &str) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -198,6 +203,7 @@ mod tests {
|
||||||
.finalize_draft("bob", "msg_1", "final text")
|
.finalize_draft("bob", "msg_1", "final text")
|
||||||
.await
|
.await
|
||||||
.is_ok());
|
.is_ok());
|
||||||
|
assert!(channel.cancel_draft("bob", "msg_1").await.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ mod tests {
|
||||||
allowed_users: vec!["alice".into()],
|
allowed_users: vec!["alice".into()],
|
||||||
stream_mode: StreamMode::default(),
|
stream_mode: StreamMode::default(),
|
||||||
draft_update_interval_ms: 1000,
|
draft_update_interval_ms: 1000,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2125,6 +2125,10 @@ pub struct TelegramConfig {
|
||||||
/// Minimum interval (ms) between draft message edits to avoid rate limits.
|
/// Minimum interval (ms) between draft message edits to avoid rate limits.
|
||||||
#[serde(default = "default_draft_update_interval_ms")]
|
#[serde(default = "default_draft_update_interval_ms")]
|
||||||
pub draft_update_interval_ms: u64,
|
pub draft_update_interval_ms: u64,
|
||||||
|
/// When true, a newer Telegram message from the same sender in the same chat
|
||||||
|
/// cancels the in-flight request and starts a fresh response with preserved history.
|
||||||
|
#[serde(default)]
|
||||||
|
pub interrupt_on_new_message: bool,
|
||||||
/// When true, only respond to messages that @-mention the bot in groups.
|
/// When true, only respond to messages that @-mention the bot in groups.
|
||||||
/// Direct messages are always processed.
|
/// Direct messages are always processed.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
|
@ -3520,6 +3524,7 @@ default_temperature = 0.7
|
||||||
allowed_users: vec!["user1".into()],
|
allowed_users: vec!["user1".into()],
|
||||||
stream_mode: StreamMode::default(),
|
stream_mode: StreamMode::default(),
|
||||||
draft_update_interval_ms: default_draft_update_interval_ms(),
|
draft_update_interval_ms: default_draft_update_interval_ms(),
|
||||||
|
interrupt_on_new_message: false,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
}),
|
}),
|
||||||
discord: None,
|
discord: None,
|
||||||
|
|
@ -3852,6 +3857,7 @@ tool_dispatcher = "xml"
|
||||||
allowed_users: vec!["alice".into(), "bob".into()],
|
allowed_users: vec!["alice".into(), "bob".into()],
|
||||||
stream_mode: StreamMode::Partial,
|
stream_mode: StreamMode::Partial,
|
||||||
draft_update_interval_ms: 500,
|
draft_update_interval_ms: 500,
|
||||||
|
interrupt_on_new_message: true,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&tc).unwrap();
|
let json = serde_json::to_string(&tc).unwrap();
|
||||||
|
|
@ -3860,6 +3866,7 @@ tool_dispatcher = "xml"
|
||||||
assert_eq!(parsed.allowed_users.len(), 2);
|
assert_eq!(parsed.allowed_users.len(), 2);
|
||||||
assert_eq!(parsed.stream_mode, StreamMode::Partial);
|
assert_eq!(parsed.stream_mode, StreamMode::Partial);
|
||||||
assert_eq!(parsed.draft_update_interval_ms, 500);
|
assert_eq!(parsed.draft_update_interval_ms, 500);
|
||||||
|
assert!(parsed.interrupt_on_new_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -3868,6 +3875,7 @@ tool_dispatcher = "xml"
|
||||||
let parsed: TelegramConfig = serde_json::from_str(json).unwrap();
|
let parsed: TelegramConfig = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(parsed.stream_mode, StreamMode::Off);
|
assert_eq!(parsed.stream_mode, StreamMode::Off);
|
||||||
assert_eq!(parsed.draft_update_interval_ms, 1000);
|
assert_eq!(parsed.draft_update_interval_ms, 1000);
|
||||||
|
assert!(!parsed.interrupt_on_new_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -321,6 +321,7 @@ mod tests {
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
stream_mode: crate::config::StreamMode::default(),
|
stream_mode: crate::config::StreamMode::default(),
|
||||||
draft_update_interval_ms: 1000,
|
draft_update_interval_ms: 1000,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
});
|
});
|
||||||
assert!(has_supervised_channels(&config));
|
assert!(has_supervised_channels(&config));
|
||||||
|
|
|
||||||
|
|
@ -790,6 +790,7 @@ mod tests {
|
||||||
allowed_users: vec!["user".into()],
|
allowed_users: vec!["user".into()],
|
||||||
stream_mode: StreamMode::default(),
|
stream_mode: StreamMode::default(),
|
||||||
draft_update_interval_ms: 1000,
|
draft_update_interval_ms: 1000,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
});
|
});
|
||||||
let entries = all_integrations();
|
let entries = all_integrations();
|
||||||
|
|
|
||||||
|
|
@ -2793,6 +2793,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
allowed_users,
|
allowed_users,
|
||||||
stream_mode: StreamMode::default(),
|
stream_mode: StreamMode::default(),
|
||||||
draft_update_interval_ms: 1000,
|
draft_update_interval_ms: 1000,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
mention_only: false,
|
mention_only: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,12 @@ struct CountingTool {
|
||||||
impl CountingTool {
|
impl CountingTool {
|
||||||
fn new() -> (Self, Arc<Mutex<usize>>) {
|
fn new() -> (Self, Arc<Mutex<usize>>) {
|
||||||
let count = Arc::new(Mutex::new(0));
|
let count = Arc::new(Mutex::new(0));
|
||||||
(Self { count: count.clone() }, count)
|
(
|
||||||
|
Self {
|
||||||
|
count: count.clone(),
|
||||||
|
},
|
||||||
|
count,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -295,10 +300,7 @@ async fn agent_handles_mixed_tool_success_and_failure() {
|
||||||
text_response("Mixed results processed"),
|
text_response("Mixed results processed"),
|
||||||
]));
|
]));
|
||||||
|
|
||||||
let mut agent = build_agent(
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool), Box::new(FailingTool)]);
|
||||||
provider,
|
|
||||||
vec![Box::new(EchoTool), Box::new(FailingTool)],
|
|
||||||
);
|
|
||||||
let response = agent.turn("mixed tools").await.unwrap();
|
let response = agent.turn("mixed tools").await.unwrap();
|
||||||
assert!(!response.is_empty());
|
assert!(!response.is_empty());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ fn channel_message_sender_field_holds_platform_user_id() {
|
||||||
content: "test message".into(),
|
content: "test message".into(),
|
||||||
channel: "telegram".into(),
|
channel: "telegram".into(),
|
||||||
timestamp: 1700000000,
|
timestamp: 1700000000,
|
||||||
|
thread_ts: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(msg.sender, "123456789");
|
assert_eq!(msg.sender, "123456789");
|
||||||
|
|
@ -45,6 +46,7 @@ fn channel_message_reply_target_distinct_from_sender() {
|
||||||
content: "test message".into(),
|
content: "test message".into(),
|
||||||
channel: "discord".into(),
|
channel: "discord".into(),
|
||||||
timestamp: 1700000000,
|
timestamp: 1700000000,
|
||||||
|
thread_ts: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_ne!(
|
assert_ne!(
|
||||||
|
|
@ -64,9 +66,13 @@ fn channel_message_fields_not_swapped() {
|
||||||
content: "payload".into(),
|
content: "payload".into(),
|
||||||
channel: "test".into(),
|
channel: "test".into(),
|
||||||
timestamp: 1700000000,
|
timestamp: 1700000000,
|
||||||
|
thread_ts: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(msg.sender, "sender_value", "sender field should not be swapped");
|
assert_eq!(
|
||||||
|
msg.sender, "sender_value",
|
||||||
|
"sender field should not be swapped"
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
msg.reply_target, "target_value",
|
msg.reply_target, "target_value",
|
||||||
"reply_target field should not be swapped"
|
"reply_target field should not be swapped"
|
||||||
|
|
@ -86,6 +92,7 @@ fn channel_message_preserves_all_fields_on_clone() {
|
||||||
content: "cloned content".into(),
|
content: "cloned content".into(),
|
||||||
channel: "test_channel".into(),
|
channel: "test_channel".into(),
|
||||||
timestamp: 1700000001,
|
timestamp: 1700000001,
|
||||||
|
thread_ts: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let cloned = original.clone();
|
let cloned = original.clone();
|
||||||
|
|
@ -170,10 +177,7 @@ impl Channel for CapturingChannel {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn listen(
|
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||||
&self,
|
|
||||||
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
tx.send(ChannelMessage {
|
tx.send(ChannelMessage {
|
||||||
id: "listen_1".into(),
|
id: "listen_1".into(),
|
||||||
sender: "test_sender".into(),
|
sender: "test_sender".into(),
|
||||||
|
|
@ -181,6 +185,7 @@ impl Channel for CapturingChannel {
|
||||||
content: "incoming".into(),
|
content: "incoming".into(),
|
||||||
channel: "capturing".into(),
|
channel: "capturing".into(),
|
||||||
timestamp: 1700000000,
|
timestamp: 1700000000,
|
||||||
|
thread_ts: None,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!(e.to_string()))
|
.map_err(|e| anyhow::anyhow!(e.to_string()))
|
||||||
|
|
@ -266,7 +271,10 @@ async fn channel_draft_defaults() {
|
||||||
.send_draft(&SendMessage::new("draft", "target"))
|
.send_draft(&SendMessage::new("draft", "target"))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert!(draft_result.is_none(), "default send_draft should return None");
|
assert!(
|
||||||
|
draft_result.is_none(),
|
||||||
|
"default send_draft should return None"
|
||||||
|
);
|
||||||
|
|
||||||
assert!(channel
|
assert!(channel
|
||||||
.update_draft("target", "msg_1", "updated")
|
.update_draft("target", "msg_1", "updated")
|
||||||
|
|
|
||||||
|
|
@ -232,7 +232,10 @@ fn workspace_dir_creation_in_tempdir() {
|
||||||
|
|
||||||
fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed");
|
fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed");
|
||||||
assert!(workspace_dir.exists(), "workspace dir should exist");
|
assert!(workspace_dir.exists(), "workspace dir should exist");
|
||||||
assert!(workspace_dir.is_dir(), "workspace path should be a directory");
|
assert!(
|
||||||
|
workspace_dir.is_dir(),
|
||||||
|
"workspace path should be a directory"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -29,10 +29,17 @@ async fn sqlite_memory_store_same_key_deduplicates() {
|
||||||
|
|
||||||
// Should have exactly 1 entry, not 2
|
// Should have exactly 1 entry, not 2
|
||||||
let count = mem.count().await.unwrap();
|
let count = mem.count().await.unwrap();
|
||||||
assert_eq!(count, 1, "storing same key twice should not create duplicates");
|
assert_eq!(
|
||||||
|
count, 1,
|
||||||
|
"storing same key twice should not create duplicates"
|
||||||
|
);
|
||||||
|
|
||||||
// Content should be the latest version
|
// Content should be the latest version
|
||||||
let entry = mem.get("greeting").await.unwrap().expect("entry should exist");
|
let entry = mem
|
||||||
|
.get("greeting")
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("entry should exist");
|
||||||
assert_eq!(entry.content, "hello updated");
|
assert_eq!(entry.content, "hello updated");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -63,7 +70,12 @@ async fn sqlite_memory_persists_across_reinitialization() {
|
||||||
// First "session": store data
|
// First "session": store data
|
||||||
{
|
{
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None)
|
mem.store(
|
||||||
|
"persistent_fact",
|
||||||
|
"Rust is great",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -158,7 +170,12 @@ async fn sqlite_memory_global_recall_includes_all_sessions() {
|
||||||
let tmp = tempfile::TempDir::new().unwrap();
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1"))
|
mem.store(
|
||||||
|
"global_a",
|
||||||
|
"alpha content",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
Some("s1"),
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2"))
|
mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2"))
|
||||||
|
|
@ -167,7 +184,10 @@ async fn sqlite_memory_global_recall_includes_all_sessions() {
|
||||||
|
|
||||||
// Global count should include all
|
// Global count should include all
|
||||||
let count = mem.count().await.unwrap();
|
let count = mem.count().await.unwrap();
|
||||||
assert_eq!(count, 2, "global count should include entries from all sessions");
|
assert_eq!(
|
||||||
|
count, 2,
|
||||||
|
"global count should include entries from all sessions"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
@ -179,10 +199,20 @@ async fn sqlite_memory_recall_returns_relevant_results() {
|
||||||
let tmp = tempfile::TempDir::new().unwrap();
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None)
|
mem.store(
|
||||||
|
"lang_pref",
|
||||||
|
"User prefers Rust programming",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None)
|
mem.store(
|
||||||
|
"food_pref",
|
||||||
|
"User likes sushi for lunch",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -229,10 +259,7 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("", 10, None).await.unwrap();
|
let results = mem.recall("", 10, None).await.unwrap();
|
||||||
assert!(
|
assert!(results.is_empty(), "empty query should return no results");
|
||||||
results.is_empty(),
|
|
||||||
"empty query should return no results"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
@ -322,7 +349,12 @@ async fn sqlite_memory_list_by_category() {
|
||||||
mem.store("daily_note", "daily note", MemoryCategory::Daily, None)
|
mem.store("daily_note", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None)
|
mem.store(
|
||||||
|
"conv_msg",
|
||||||
|
"conversation msg",
|
||||||
|
MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,10 @@ fn tool_call_has_required_fields() {
|
||||||
|
|
||||||
let json = serde_json::to_value(&tc).unwrap();
|
let json = serde_json::to_value(&tc).unwrap();
|
||||||
assert!(json.get("id").is_some(), "ToolCall must have 'id' field");
|
assert!(json.get("id").is_some(), "ToolCall must have 'id' field");
|
||||||
assert!(json.get("name").is_some(), "ToolCall must have 'name' field");
|
assert!(
|
||||||
|
json.get("name").is_some(),
|
||||||
|
"ToolCall must have 'name' field"
|
||||||
|
);
|
||||||
assert!(
|
assert!(
|
||||||
json.get("arguments").is_some(),
|
json.get("arguments").is_some(),
|
||||||
"ToolCall must have 'arguments' field"
|
"ToolCall must have 'arguments' field"
|
||||||
|
|
@ -98,7 +101,10 @@ fn tool_call_id_preserved_in_serialization() {
|
||||||
let json_str = serde_json::to_string(&tc).unwrap();
|
let json_str = serde_json::to_string(&tc).unwrap();
|
||||||
let parsed: ToolCall = serde_json::from_str(&json_str).unwrap();
|
let parsed: ToolCall = serde_json::from_str(&json_str).unwrap();
|
||||||
|
|
||||||
assert_eq!(parsed.id, "call_deepseek_42", "tool_call_id must survive roundtrip");
|
assert_eq!(
|
||||||
|
parsed.id, "call_deepseek_42",
|
||||||
|
"tool_call_id must survive roundtrip"
|
||||||
|
);
|
||||||
assert_eq!(parsed.name, "shell");
|
assert_eq!(parsed.name, "shell");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -111,8 +117,8 @@ fn tool_call_arguments_contain_valid_json() {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Arguments should parse as valid JSON
|
// Arguments should parse as valid JSON
|
||||||
let args: serde_json::Value = serde_json::from_str(&tc.arguments)
|
let args: serde_json::Value =
|
||||||
.expect("tool call arguments should be valid JSON");
|
serde_json::from_str(&tc.arguments).expect("tool call arguments should be valid JSON");
|
||||||
assert!(args.get("path").is_some());
|
assert!(args.get("path").is_some());
|
||||||
assert!(args.get("content").is_some());
|
assert!(args.get("content").is_some());
|
||||||
}
|
}
|
||||||
|
|
@ -125,9 +131,8 @@ fn tool_call_arguments_contain_valid_json() {
|
||||||
fn tool_response_message_can_embed_tool_call_id() {
|
fn tool_response_message_can_embed_tool_call_id() {
|
||||||
// DeepSeek requires tool_call_id in tool response messages.
|
// DeepSeek requires tool_call_id in tool response messages.
|
||||||
// The tool message content can embed the tool_call_id as JSON.
|
// The tool message content can embed the tool_call_id as JSON.
|
||||||
let tool_response = ChatMessage::tool(
|
let tool_response =
|
||||||
r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#,
|
ChatMessage::tool(r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#);
|
||||||
);
|
|
||||||
|
|
||||||
let parsed: serde_json::Value = serde_json::from_str(&tool_response.content)
|
let parsed: serde_json::Value = serde_json::from_str(&tool_response.content)
|
||||||
.expect("tool response content should be valid JSON");
|
.expect("tool response content should be valid JSON");
|
||||||
|
|
@ -245,21 +250,32 @@ fn provider_construction_with_different_names() {
|
||||||
Some("test-key"),
|
Some("test-key"),
|
||||||
AuthStyle::Bearer,
|
AuthStyle::Bearer,
|
||||||
);
|
);
|
||||||
let _p2 = OpenAiCompatibleProvider::new(
|
let _p2 =
|
||||||
"deepseek",
|
OpenAiCompatibleProvider::new("deepseek", "https://api.test.com", None, AuthStyle::Bearer);
|
||||||
"https://api.test.com",
|
|
||||||
None,
|
|
||||||
AuthStyle::Bearer,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn provider_construction_with_different_auth_styles() {
|
fn provider_construction_with_different_auth_styles() {
|
||||||
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
|
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
|
||||||
|
|
||||||
let _bearer = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Bearer);
|
let _bearer = OpenAiCompatibleProvider::new(
|
||||||
let _xapi = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::XApiKey);
|
"Test",
|
||||||
let _custom = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Custom("X-My-Auth".into()));
|
"https://api.test.com",
|
||||||
|
Some("key"),
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
let _xapi = OpenAiCompatibleProvider::new(
|
||||||
|
"Test",
|
||||||
|
"https://api.test.com",
|
||||||
|
Some("key"),
|
||||||
|
AuthStyle::XApiKey,
|
||||||
|
);
|
||||||
|
let _custom = OpenAiCompatibleProvider::new(
|
||||||
|
"Test",
|
||||||
|
"https://api.test.com",
|
||||||
|
Some("key"),
|
||||||
|
AuthStyle::Custom("X-My-Auth".into()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue