diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index e21e760..941e4d0 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -346,7 +346,7 @@ fn parse_tool_calls_from_json_value(value: &serde_json::Value) -> Vec", "", ""]; +const TOOL_CALL_OPEN_TAGS: [&str; 4] = ["", "", "", ""]; fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> { tags.iter() @@ -359,10 +359,47 @@ fn matching_tool_call_close_tag(open_tag: &str) -> Option<&'static str> { "" => Some(""), "" => Some(""), "" => Some(""), + "" => Some(""), _ => None, } } +fn extract_first_json_value_with_end(input: &str) -> Option<(serde_json::Value, usize)> { + let trimmed = input.trim_start(); + let trim_offset = input.len().saturating_sub(trimmed.len()); + + for (byte_idx, ch) in trimmed.char_indices() { + if ch != '{' && ch != '[' { + continue; + } + + let slice = &trimmed[byte_idx..]; + let mut stream = serde_json::Deserializer::from_str(slice).into_iter::(); + if let Some(Ok(value)) = stream.next() { + let consumed = stream.byte_offset(); + if consumed > 0 { + return Some((value, trim_offset + byte_idx + consumed)); + } + } + } + + None +} + +fn strip_leading_close_tags(mut input: &str) -> &str { + loop { + let trimmed = input.trim_start(); + if !trimmed.starts_with("') else { + return ""; + }; + input = &trimmed[close_end + 1..]; + } +} + /// Extract JSON values from a string. /// /// # Security Warning @@ -607,20 +644,27 @@ fn parse_tool_calls(response: &str) -> (String, Vec) { remaining = &after_open[close_idx + close_tag.len()..]; } else { if let Some(json_end) = find_json_end(after_open) { - if let Ok(value) = - serde_json::from_str::(&after_open[..json_end]) + if let Ok(value) = serde_json::from_str::(&after_open[..json_end]) { let parsed_calls = parse_tool_calls_from_json_value(&value); if !parsed_calls.is_empty() { calls.extend(parsed_calls); - let after_json = &after_open[json_end..]; - if !after_json.trim().is_empty() { - text_parts.push(after_json.trim().to_string()); - } - remaining = ""; + remaining = strip_leading_close_tags(&after_open[json_end..]); + continue; } } } + + if let Some((value, consumed_end)) = extract_first_json_value_with_end(after_open) { + let parsed_calls = parse_tool_calls_from_json_value(&value); + if !parsed_calls.is_empty() { + calls.extend(parsed_calls); + remaining = strip_leading_close_tags(&after_open[consumed_end..]); + continue; + } + } + + remaining = &remaining[start..]; break; } } @@ -630,8 +674,10 @@ fn parse_tool_calls(response: &str) -> (String, Vec) { // ```tool_call ... instead of structured API calls or XML tags. if calls.is_empty() { static MD_TOOL_CALL_RE: LazyLock = LazyLock::new(|| { - Regex::new(r"(?s)```tool[_-]?call\s*\n(.*?)(?:```||)") - .unwrap() + Regex::new( + r"(?s)```(?:tool[_-]?call|invoke)\s*\n(.*?)(?:```|||)", + ) + .unwrap() }); let mut md_text_parts: Vec = Vec::new(); let mut last_end = 0; @@ -1932,6 +1978,25 @@ Tail"#; assert!(!text.contains("```tool-call")); } + #[test] + fn parse_tool_calls_handles_markdown_invoke_fence() { + let response = r#"Checking. +```invoke +{"name": "shell", "arguments": {"command": "date"}} +``` +Done."#; + + let (text, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "date" + ); + assert!(text.contains("Checking.")); + assert!(text.contains("Done.")); + } + #[test] fn parse_tool_calls_handles_toolcall_tag_alias() { let response = r#" @@ -1965,15 +2030,63 @@ Tail"#; } #[test] - fn parse_tool_calls_does_not_cross_match_alias_tags() { + fn parse_tool_calls_handles_invoke_tag_alias() { + let response = r#" +{"name": "shell", "arguments": {"command": "uptime"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "uptime" + ); + } + + #[test] + fn parse_tool_calls_recovers_unclosed_tool_call_with_json() { + let response = r#"I will call the tool now. + +{"name": "shell", "arguments": {"command": "uptime -p"}}"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("I will call the tool now.")); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "uptime -p" + ); + } + + #[test] + fn parse_tool_calls_recovers_mismatched_close_tag() { + let response = r#" +{"name": "shell", "arguments": {"command": "uptime"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "uptime" + ); + } + + #[test] + fn parse_tool_calls_recovers_cross_alias_closing_tags() { let response = r#" {"name": "shell", "arguments": {"command": "date"}} "#; let (text, calls) = parse_tool_calls(response); - assert!(calls.is_empty()); - assert!(text.contains("")); - assert!(text.contains("")); + assert!(text.is_empty()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); } #[test] diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 82430a8..6cd71ee 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -144,38 +144,103 @@ fn parse_path_only_attachment(message: &str) -> Option { /// These tags are used internally but must not be sent to Telegram as raw markup, /// since Telegram's Markdown parser will reject them (causing status 400 errors). fn strip_tool_call_tags(message: &str) -> String { - let mut result = message.to_string(); + const TOOL_CALL_OPEN_TAGS: [&str; 5] = [ + "", + "", + "", + "", + "", + ]; - // Strip ... - while let Some(start) = result.find("") { - if let Some(end) = result[start..].find("") { - let end = start + end + "".len(); - result = format!("{}{}", &result[..start], &result[end..]); - } else { - break; + fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> { + tags.iter() + .filter_map(|tag| haystack.find(tag).map(|idx| (idx, *tag))) + .min_by_key(|(idx, _)| *idx) + } + + fn matching_close_tag(open_tag: &str) -> Option<&'static str> { + match open_tag { + "" => Some(""), + "" => Some(""), + "" => Some(""), + "" => Some(""), + "" => Some(""), + _ => None, } } - // Strip ... - while let Some(start) = result.find("") { - if let Some(end) = result[start..].find("") { - let end = start + end + "".len(); - result = format!("{}{}", &result[..start], &result[end..]); - } else { - break; + fn extract_first_json_end(input: &str) -> Option { + let trimmed = input.trim_start(); + let trim_offset = input.len().saturating_sub(trimmed.len()); + + for (byte_idx, ch) in trimmed.char_indices() { + if ch != '{' && ch != '[' { + continue; + } + + let slice = &trimmed[byte_idx..]; + let mut stream = + serde_json::Deserializer::from_str(slice).into_iter::(); + if let Some(Ok(_value)) = stream.next() { + let consumed = stream.byte_offset(); + if consumed > 0 { + return Some(trim_offset + byte_idx + consumed); + } + } + } + + None + } + + fn strip_leading_close_tags(mut input: &str) -> &str { + loop { + let trimmed = input.trim_start(); + if !trimmed.starts_with("') else { + return ""; + }; + input = &trimmed[close_end + 1..]; } } - // Strip ... - while let Some(start) = result.find("") { - if let Some(end) = result[start..].find("") { - let end = start + end + "".len(); - result = format!("{}{}", &result[..start], &result[end..]); - } else { - break; + let mut kept_segments = Vec::new(); + let mut remaining = message; + + while let Some((start, open_tag)) = find_first_tag(remaining, &TOOL_CALL_OPEN_TAGS) { + let before = &remaining[..start]; + if !before.is_empty() { + kept_segments.push(before.to_string()); } + + let Some(close_tag) = matching_close_tag(open_tag) else { + break; + }; + let after_open = &remaining[start + open_tag.len()..]; + + if let Some(close_idx) = after_open.find(close_tag) { + remaining = &after_open[close_idx + close_tag.len()..]; + continue; + } + + if let Some(consumed_end) = extract_first_json_end(after_open) { + remaining = strip_leading_close_tags(&after_open[consumed_end..]); + continue; + } + + kept_segments.push(remaining[start..].to_string()); + remaining = ""; + break; } + if !remaining.is_empty() { + kept_segments.push(remaining.to_string()); + } + + let mut result = kept_segments.concat(); + // Clean up any resulting blank lines (but preserve paragraphs) while result.contains("\n\n\n") { result = result.replace("\n\n\n", "\n\n"); @@ -2373,6 +2438,20 @@ mod tests { assert_eq!(result, "Hello world"); } + #[test] + fn strip_tool_call_tags_removes_tool_call_tags() { + let input = "Hello {\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}} world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + + #[test] + fn strip_tool_call_tags_removes_invoke_tags() { + let input = "Hello {\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}} world"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Hello world"); + } + #[test] fn strip_tool_call_tags_handles_multiple_tags() { let input = "Start a middle b end"; @@ -2401,6 +2480,22 @@ mod tests { assert_eq!(result, "Hello world"); } + #[test] + fn strip_tool_call_tags_handles_unclosed_tool_call_with_json() { + let input = + "Status:\n\n{\"name\":\"shell\",\"arguments\":{\"command\":\"uptime\"}}"; + let result = strip_tool_call_tags(input); + assert_eq!(result, "Status:"); + } + + #[test] + fn strip_tool_call_tags_handles_mismatched_close_tag() { + let input = + "{\"name\":\"shell\",\"arguments\":{\"command\":\"uptime\"}}"; + let result = strip_tool_call_tags(input); + assert_eq!(result, ""); + } + #[test] fn strip_tool_call_tags_cleans_extra_newlines() { let input = "Hello\n\n\ntest\n\n\n\nworld"; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index d6d16e2..45f9734 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1028,8 +1028,9 @@ mod tests { auto_save: false, webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, observer: Arc::new(crate::observability::NoopObserver), @@ -1068,8 +1069,9 @@ mod tests { auto_save: false, webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, observer, diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs index 1095f04..5d7d043 100644 --- a/src/tools/memory_store.rs +++ b/src/tools/memory_store.rs @@ -150,7 +150,7 @@ mod tests { #[tokio::test] async fn store_with_custom_category() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem.clone()); + let tool = MemoryStoreTool::new(mem.clone(), test_security()); let result = tool .execute( json!({"key": "proj_note", "content": "Uses async runtime", "category": "project"}),