fix(channels): recover malformed invoke/tool_call output in daemon mode
This commit is contained in:
parent
75a9eb383c
commit
219764d4d8
4 changed files with 251 additions and 41 deletions
|
|
@ -346,7 +346,7 @@ fn parse_tool_calls_from_json_value(value: &serde_json::Value) -> Vec<ParsedTool
|
||||||
calls
|
calls
|
||||||
}
|
}
|
||||||
|
|
||||||
const TOOL_CALL_OPEN_TAGS: [&str; 3] = ["<tool_call>", "<toolcall>", "<tool-call>"];
|
const TOOL_CALL_OPEN_TAGS: [&str; 4] = ["<tool_call>", "<toolcall>", "<tool-call>", "<invoke>"];
|
||||||
|
|
||||||
fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> {
|
fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> {
|
||||||
tags.iter()
|
tags.iter()
|
||||||
|
|
@ -359,10 +359,47 @@ fn matching_tool_call_close_tag(open_tag: &str) -> Option<&'static str> {
|
||||||
"<tool_call>" => Some("</tool_call>"),
|
"<tool_call>" => Some("</tool_call>"),
|
||||||
"<toolcall>" => Some("</toolcall>"),
|
"<toolcall>" => Some("</toolcall>"),
|
||||||
"<tool-call>" => Some("</tool-call>"),
|
"<tool-call>" => Some("</tool-call>"),
|
||||||
|
"<invoke>" => Some("</invoke>"),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn extract_first_json_value_with_end(input: &str) -> Option<(serde_json::Value, usize)> {
|
||||||
|
let trimmed = input.trim_start();
|
||||||
|
let trim_offset = input.len().saturating_sub(trimmed.len());
|
||||||
|
|
||||||
|
for (byte_idx, ch) in trimmed.char_indices() {
|
||||||
|
if ch != '{' && ch != '[' {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let slice = &trimmed[byte_idx..];
|
||||||
|
let mut stream = serde_json::Deserializer::from_str(slice).into_iter::<serde_json::Value>();
|
||||||
|
if let Some(Ok(value)) = stream.next() {
|
||||||
|
let consumed = stream.byte_offset();
|
||||||
|
if consumed > 0 {
|
||||||
|
return Some((value, trim_offset + byte_idx + consumed));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_leading_close_tags(mut input: &str) -> &str {
|
||||||
|
loop {
|
||||||
|
let trimmed = input.trim_start();
|
||||||
|
if !trimmed.starts_with("</") {
|
||||||
|
return trimmed;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(close_end) = trimmed.find('>') else {
|
||||||
|
return "";
|
||||||
|
};
|
||||||
|
input = &trimmed[close_end + 1..];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract JSON values from a string.
|
/// Extract JSON values from a string.
|
||||||
///
|
///
|
||||||
/// # Security Warning
|
/// # Security Warning
|
||||||
|
|
@ -607,20 +644,27 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
remaining = &after_open[close_idx + close_tag.len()..];
|
remaining = &after_open[close_idx + close_tag.len()..];
|
||||||
} else {
|
} else {
|
||||||
if let Some(json_end) = find_json_end(after_open) {
|
if let Some(json_end) = find_json_end(after_open) {
|
||||||
if let Ok(value) =
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&after_open[..json_end])
|
||||||
serde_json::from_str::<serde_json::Value>(&after_open[..json_end])
|
|
||||||
{
|
{
|
||||||
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
||||||
if !parsed_calls.is_empty() {
|
if !parsed_calls.is_empty() {
|
||||||
calls.extend(parsed_calls);
|
calls.extend(parsed_calls);
|
||||||
let after_json = &after_open[json_end..];
|
remaining = strip_leading_close_tags(&after_open[json_end..]);
|
||||||
if !after_json.trim().is_empty() {
|
continue;
|
||||||
text_parts.push(after_json.trim().to_string());
|
|
||||||
}
|
|
||||||
remaining = "";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some((value, consumed_end)) = extract_first_json_value_with_end(after_open) {
|
||||||
|
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
||||||
|
if !parsed_calls.is_empty() {
|
||||||
|
calls.extend(parsed_calls);
|
||||||
|
remaining = strip_leading_close_tags(&after_open[consumed_end..]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining = &remaining[start..];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -630,8 +674,10 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
// ```tool_call ... </tool_call> instead of structured API calls or XML tags.
|
// ```tool_call ... </tool_call> instead of structured API calls or XML tags.
|
||||||
if calls.is_empty() {
|
if calls.is_empty() {
|
||||||
static MD_TOOL_CALL_RE: LazyLock<Regex> = LazyLock::new(|| {
|
static MD_TOOL_CALL_RE: LazyLock<Regex> = LazyLock::new(|| {
|
||||||
Regex::new(r"(?s)```tool[_-]?call\s*\n(.*?)(?:```|</tool[_-]?call>|</toolcall>)")
|
Regex::new(
|
||||||
.unwrap()
|
r"(?s)```(?:tool[_-]?call|invoke)\s*\n(.*?)(?:```|</tool[_-]?call>|</toolcall>|</invoke>)",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
});
|
});
|
||||||
let mut md_text_parts: Vec<String> = Vec::new();
|
let mut md_text_parts: Vec<String> = Vec::new();
|
||||||
let mut last_end = 0;
|
let mut last_end = 0;
|
||||||
|
|
@ -1932,6 +1978,25 @@ Tail"#;
|
||||||
assert!(!text.contains("```tool-call"));
|
assert!(!text.contains("```tool-call"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_handles_markdown_invoke_fence() {
|
||||||
|
let response = r#"Checking.
|
||||||
|
```invoke
|
||||||
|
{"name": "shell", "arguments": {"command": "date"}}
|
||||||
|
```
|
||||||
|
Done."#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"date"
|
||||||
|
);
|
||||||
|
assert!(text.contains("Checking."));
|
||||||
|
assert!(text.contains("Done."));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_tool_calls_handles_toolcall_tag_alias() {
|
fn parse_tool_calls_handles_toolcall_tag_alias() {
|
||||||
let response = r#"<toolcall>
|
let response = r#"<toolcall>
|
||||||
|
|
@ -1965,15 +2030,63 @@ Tail"#;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_tool_calls_does_not_cross_match_alias_tags() {
|
fn parse_tool_calls_handles_invoke_tag_alias() {
|
||||||
|
let response = r#"<invoke>
|
||||||
|
{"name": "shell", "arguments": {"command": "uptime"}}
|
||||||
|
</invoke>"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.is_empty());
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"uptime"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_recovers_unclosed_tool_call_with_json() {
|
||||||
|
let response = r#"I will call the tool now.
|
||||||
|
<tool_call>
|
||||||
|
{"name": "shell", "arguments": {"command": "uptime -p"}}"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.contains("I will call the tool now."));
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"uptime -p"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_recovers_mismatched_close_tag() {
|
||||||
|
let response = r#"<tool_call>
|
||||||
|
{"name": "shell", "arguments": {"command": "uptime"}}
|
||||||
|
</arg_value>"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.is_empty());
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"uptime"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_recovers_cross_alias_closing_tags() {
|
||||||
let response = r#"<toolcall>
|
let response = r#"<toolcall>
|
||||||
{"name": "shell", "arguments": {"command": "date"}}
|
{"name": "shell", "arguments": {"command": "date"}}
|
||||||
</tool_call>"#;
|
</tool_call>"#;
|
||||||
|
|
||||||
let (text, calls) = parse_tool_calls(response);
|
let (text, calls) = parse_tool_calls(response);
|
||||||
assert!(calls.is_empty());
|
assert!(text.is_empty());
|
||||||
assert!(text.contains("<toolcall>"));
|
assert_eq!(calls.len(), 1);
|
||||||
assert!(text.contains("</tool_call>"));
|
assert_eq!(calls[0].name, "shell");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -144,38 +144,103 @@ fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
|
||||||
/// These tags are used internally but must not be sent to Telegram as raw markup,
|
/// These tags are used internally but must not be sent to Telegram as raw markup,
|
||||||
/// since Telegram's Markdown parser will reject them (causing status 400 errors).
|
/// since Telegram's Markdown parser will reject them (causing status 400 errors).
|
||||||
fn strip_tool_call_tags(message: &str) -> String {
|
fn strip_tool_call_tags(message: &str) -> String {
|
||||||
let mut result = message.to_string();
|
const TOOL_CALL_OPEN_TAGS: [&str; 5] = [
|
||||||
|
"<tool_call>",
|
||||||
|
"<toolcall>",
|
||||||
|
"<tool-call>",
|
||||||
|
"<tool>",
|
||||||
|
"<invoke>",
|
||||||
|
];
|
||||||
|
|
||||||
// Strip <tool>...</tool>
|
fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> {
|
||||||
while let Some(start) = result.find("<tool>") {
|
tags.iter()
|
||||||
if let Some(end) = result[start..].find("</tool>") {
|
.filter_map(|tag| haystack.find(tag).map(|idx| (idx, *tag)))
|
||||||
let end = start + end + "</tool>".len();
|
.min_by_key(|(idx, _)| *idx)
|
||||||
result = format!("{}{}", &result[..start], &result[end..]);
|
}
|
||||||
} else {
|
|
||||||
break;
|
fn matching_close_tag(open_tag: &str) -> Option<&'static str> {
|
||||||
|
match open_tag {
|
||||||
|
"<tool_call>" => Some("</tool_call>"),
|
||||||
|
"<toolcall>" => Some("</toolcall>"),
|
||||||
|
"<tool-call>" => Some("</tool-call>"),
|
||||||
|
"<tool>" => Some("</tool>"),
|
||||||
|
"<invoke>" => Some("</invoke>"),
|
||||||
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strip <toolcall>...</toolcall>
|
fn extract_first_json_end(input: &str) -> Option<usize> {
|
||||||
while let Some(start) = result.find("<toolcall>") {
|
let trimmed = input.trim_start();
|
||||||
if let Some(end) = result[start..].find("</toolcall>") {
|
let trim_offset = input.len().saturating_sub(trimmed.len());
|
||||||
let end = start + end + "</toolcall>".len();
|
|
||||||
result = format!("{}{}", &result[..start], &result[end..]);
|
for (byte_idx, ch) in trimmed.char_indices() {
|
||||||
} else {
|
if ch != '{' && ch != '[' {
|
||||||
break;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let slice = &trimmed[byte_idx..];
|
||||||
|
let mut stream =
|
||||||
|
serde_json::Deserializer::from_str(slice).into_iter::<serde_json::Value>();
|
||||||
|
if let Some(Ok(_value)) = stream.next() {
|
||||||
|
let consumed = stream.byte_offset();
|
||||||
|
if consumed > 0 {
|
||||||
|
return Some(trim_offset + byte_idx + consumed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_leading_close_tags(mut input: &str) -> &str {
|
||||||
|
loop {
|
||||||
|
let trimmed = input.trim_start();
|
||||||
|
if !trimmed.starts_with("</") {
|
||||||
|
return trimmed;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(close_end) = trimmed.find('>') else {
|
||||||
|
return "";
|
||||||
|
};
|
||||||
|
input = &trimmed[close_end + 1..];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Strip <tool-call>...</tool-call>
|
let mut kept_segments = Vec::new();
|
||||||
while let Some(start) = result.find("<tool-call>") {
|
let mut remaining = message;
|
||||||
if let Some(end) = result[start..].find("</tool-call>") {
|
|
||||||
let end = start + end + "</tool-call>".len();
|
while let Some((start, open_tag)) = find_first_tag(remaining, &TOOL_CALL_OPEN_TAGS) {
|
||||||
result = format!("{}{}", &result[..start], &result[end..]);
|
let before = &remaining[..start];
|
||||||
} else {
|
if !before.is_empty() {
|
||||||
break;
|
kept_segments.push(before.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let Some(close_tag) = matching_close_tag(open_tag) else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
let after_open = &remaining[start + open_tag.len()..];
|
||||||
|
|
||||||
|
if let Some(close_idx) = after_open.find(close_tag) {
|
||||||
|
remaining = &after_open[close_idx + close_tag.len()..];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(consumed_end) = extract_first_json_end(after_open) {
|
||||||
|
remaining = strip_leading_close_tags(&after_open[consumed_end..]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
kept_segments.push(remaining[start..].to_string());
|
||||||
|
remaining = "";
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !remaining.is_empty() {
|
||||||
|
kept_segments.push(remaining.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = kept_segments.concat();
|
||||||
|
|
||||||
// Clean up any resulting blank lines (but preserve paragraphs)
|
// Clean up any resulting blank lines (but preserve paragraphs)
|
||||||
while result.contains("\n\n\n") {
|
while result.contains("\n\n\n") {
|
||||||
result = result.replace("\n\n\n", "\n\n");
|
result = result.replace("\n\n\n", "\n\n");
|
||||||
|
|
@ -2373,6 +2438,20 @@ mod tests {
|
||||||
assert_eq!(result, "Hello world");
|
assert_eq!(result, "Hello world");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn strip_tool_call_tags_removes_tool_call_tags() {
|
||||||
|
let input = "Hello <tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call> world";
|
||||||
|
let result = strip_tool_call_tags(input);
|
||||||
|
assert_eq!(result, "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn strip_tool_call_tags_removes_invoke_tags() {
|
||||||
|
let input = "Hello <invoke>{\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}}</invoke> world";
|
||||||
|
let result = strip_tool_call_tags(input);
|
||||||
|
assert_eq!(result, "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn strip_tool_call_tags_handles_multiple_tags() {
|
fn strip_tool_call_tags_handles_multiple_tags() {
|
||||||
let input = "Start <tool>a</tool> middle <tool>b</tool> end";
|
let input = "Start <tool>a</tool> middle <tool>b</tool> end";
|
||||||
|
|
@ -2401,6 +2480,22 @@ mod tests {
|
||||||
assert_eq!(result, "Hello <tool>world");
|
assert_eq!(result, "Hello <tool>world");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn strip_tool_call_tags_handles_unclosed_tool_call_with_json() {
|
||||||
|
let input =
|
||||||
|
"Status:\n<tool_call>\n{\"name\":\"shell\",\"arguments\":{\"command\":\"uptime\"}}";
|
||||||
|
let result = strip_tool_call_tags(input);
|
||||||
|
assert_eq!(result, "Status:");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn strip_tool_call_tags_handles_mismatched_close_tag() {
|
||||||
|
let input =
|
||||||
|
"<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"uptime\"}}</arg_value>";
|
||||||
|
let result = strip_tool_call_tags(input);
|
||||||
|
assert_eq!(result, "");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn strip_tool_call_tags_cleans_extra_newlines() {
|
fn strip_tool_call_tags_cleans_extra_newlines() {
|
||||||
let input = "Hello\n\n<tool>\ntest\n</tool>\n\n\nworld";
|
let input = "Hello\n\n<tool>\ntest\n</tool>\n\n\nworld";
|
||||||
|
|
|
||||||
|
|
@ -1028,8 +1028,9 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
observer: Arc::new(crate::observability::NoopObserver),
|
observer: Arc::new(crate::observability::NoopObserver),
|
||||||
|
|
@ -1068,8 +1069,9 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
observer,
|
observer,
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_with_custom_category() {
|
async fn store_with_custom_category() {
|
||||||
let (_tmp, mem) = test_mem();
|
let (_tmp, mem) = test_mem();
|
||||||
let tool = MemoryStoreTool::new(mem.clone());
|
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(
|
.execute(
|
||||||
json!({"key": "proj_note", "content": "Uses async runtime", "category": "project"}),
|
json!({"key": "proj_note", "content": "Uses async runtime", "category": "project"}),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue