fix(channels): interrupt in-flight telegram requests on newer sender messages
This commit is contained in:
parent
d9a94fc763
commit
ef82c7dbcd
17 changed files with 669 additions and 115 deletions
|
|
@ -16,6 +16,7 @@ use std::fmt::Write;
|
|||
use std::io::Write as _;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
|
||||
|
|
@ -823,6 +824,21 @@ struct ParsedToolCall {
|
|||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ToolLoopCancelled;
|
||||
|
||||
impl std::fmt::Display for ToolLoopCancelled {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("tool loop cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ToolLoopCancelled {}
|
||||
|
||||
pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool {
|
||||
err.chain().any(|source| source.is::<ToolLoopCancelled>())
|
||||
}
|
||||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
|
|
@ -853,6 +869,7 @@ pub(crate) async fn agent_turn(
|
|||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -873,6 +890,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
channel_name: &str,
|
||||
multimodal_config: &crate::config::MultimodalConfig,
|
||||
max_tool_iterations: usize,
|
||||
cancellation_token: Option<CancellationToken>,
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
|
|
@ -886,6 +904,13 @@ pub(crate) async fn run_tool_call_loop(
|
|||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
for _iteration in 0..max_iterations {
|
||||
if cancellation_token
|
||||
.as_ref()
|
||||
.is_some_and(CancellationToken::is_cancelled)
|
||||
{
|
||||
return Err(ToolLoopCancelled.into());
|
||||
}
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
|
|
@ -917,18 +942,26 @@ pub(crate) async fn run_tool_call_loop(
|
|||
None
|
||||
};
|
||||
|
||||
let chat_future = provider.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
);
|
||||
|
||||
let chat_result = if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
};
|
||||
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||
match provider
|
||||
.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
{
|
||||
match chat_result {
|
||||
Ok(resp) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
|
|
@ -994,6 +1027,12 @@ pub(crate) async fn run_tool_call_loop(
|
|||
// STREAM_CHUNK_MIN_CHARS characters for progressive draft updates.
|
||||
let mut chunk = String::new();
|
||||
for word in display_text.split_inclusive(char::is_whitespace) {
|
||||
if cancellation_token
|
||||
.as_ref()
|
||||
.is_some_and(CancellationToken::is_cancelled)
|
||||
{
|
||||
return Err(ToolLoopCancelled.into());
|
||||
}
|
||||
chunk.push_str(word);
|
||||
if chunk.len() >= STREAM_CHUNK_MIN_CHARS
|
||||
&& tx.send(std::mem::take(&mut chunk)).await.is_err()
|
||||
|
|
@ -1056,7 +1095,17 @@ pub(crate) async fn run_tool_call_loop(
|
|||
});
|
||||
let start = Instant::now();
|
||||
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
|
||||
match tool.execute(call.arguments.clone()).await {
|
||||
let tool_future = tool.execute(call.arguments.clone());
|
||||
let tool_result = if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = tool_future => result,
|
||||
}
|
||||
} else {
|
||||
tool_future.await
|
||||
};
|
||||
|
||||
match tool_result {
|
||||
Ok(r) => {
|
||||
observer.record_event(&ObserverEvent::ToolCall {
|
||||
tool: call.name.clone(),
|
||||
|
|
@ -1435,6 +1484,7 @@ pub async fn run(
|
|||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
|
|
@ -1553,6 +1603,7 @@ pub async fn run(
|
|||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -1900,6 +1951,7 @@ mod tests {
|
|||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
|
|
@ -1943,6 +1995,7 @@ mod tests {
|
|||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
|
|
@ -1980,6 +2033,7 @@ mod tests {
|
|||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
|
|
@ -2809,7 +2863,10 @@ browser_open/url>https://example.com"#;
|
|||
fn parse_tool_calls_closing_tag_only_returns_text() {
|
||||
let response = "Some text </tool_call> more text";
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
assert!(calls.is_empty(), "closing tag only should not produce calls");
|
||||
assert!(
|
||||
calls.is_empty(),
|
||||
"closing tag only should not produce calls"
|
||||
);
|
||||
assert!(
|
||||
!text.is_empty(),
|
||||
"text around orphaned closing tag should be preserved"
|
||||
|
|
@ -2858,7 +2915,11 @@ browser_open/url>https://example.com"#;
|
|||
|
||||
Let me check the result."#;
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
assert_eq!(calls.len(), 1, "should extract one tool call from mixed content");
|
||||
assert_eq!(
|
||||
calls.len(),
|
||||
1,
|
||||
"should extract one tool call from mixed content"
|
||||
);
|
||||
assert_eq!(calls[0].name, "shell");
|
||||
assert!(
|
||||
text.contains("help you"),
|
||||
|
|
@ -2880,7 +2941,10 @@ Let me check the result."#;
|
|||
fn scrub_credentials_no_sensitive_data() {
|
||||
let input = "normal text without any secrets";
|
||||
let result = scrub_credentials(input);
|
||||
assert_eq!(result, input, "non-sensitive text should pass through unchanged");
|
||||
assert_eq!(
|
||||
result, input,
|
||||
"non-sensitive text should pass through unchanged"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue