fix(agent): recover malformed tool_call blocks with leading text
This commit is contained in:
parent
59f74e8f39
commit
af5d1f3066
1 changed file with 148 additions and 22 deletions
|
|
@ -409,10 +409,11 @@ fn extract_json_values(input: &str) -> Vec<serde_json::Value> {
|
|||
/// compatibility.
|
||||
///
|
||||
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
|
||||
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>, bool) {
|
||||
let mut text_parts = Vec::new();
|
||||
let mut calls = Vec::new();
|
||||
let mut remaining = response;
|
||||
let mut malformed_markup = false;
|
||||
|
||||
// First, try to parse as OpenAI-style JSON response with tool_calls array
|
||||
// This handles providers like Minimax that return tool_calls in native JSON format
|
||||
|
|
@ -425,7 +426,7 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
|||
text_parts.push(content.trim().to_string());
|
||||
}
|
||||
}
|
||||
return (text_parts.join("\n"), calls);
|
||||
return (text_parts.join("\n"), calls, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -456,10 +457,12 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
|||
|
||||
if !parsed_any {
|
||||
tracing::warn!("Malformed <tool_call> JSON: expected tool-call object in tag body");
|
||||
malformed_markup = true;
|
||||
}
|
||||
|
||||
remaining = &after_open[close_idx + close_tag.len()..];
|
||||
} else {
|
||||
malformed_markup = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -477,7 +480,7 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
|||
text_parts.push(remaining.trim().to_string());
|
||||
}
|
||||
|
||||
(text_parts.join("\n"), calls)
|
||||
(text_parts.join("\n"), calls, malformed_markup)
|
||||
}
|
||||
|
||||
fn parse_structured_tool_calls(tool_calls: &[ToolCall]) -> Vec<ParsedToolCall> {
|
||||
|
|
@ -593,7 +596,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
let llm_started_at = Instant::now();
|
||||
|
||||
// Choose between native tool-call API and prompt-based tool use.
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content) =
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, malformed_markup) =
|
||||
if use_native_tools {
|
||||
match provider
|
||||
.chat_with_tools(history, &tool_definitions, model, temperature)
|
||||
|
|
@ -610,13 +613,16 @@ pub(crate) async fn run_tool_call_loop(
|
|||
let response_text = resp.text_or_empty().to_string();
|
||||
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
||||
let mut parsed_text = String::new();
|
||||
let mut malformed_markup = false;
|
||||
|
||||
if calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
let (fallback_text, fallback_calls, fallback_malformed_markup) =
|
||||
parse_tool_calls(&response_text);
|
||||
if !fallback_text.is_empty() {
|
||||
parsed_text = fallback_text;
|
||||
}
|
||||
calls = fallback_calls;
|
||||
malformed_markup = fallback_malformed_markup;
|
||||
}
|
||||
|
||||
let assistant_history_content = if resp.tool_calls.is_empty() {
|
||||
|
|
@ -628,7 +634,13 @@ pub(crate) async fn run_tool_call_loop(
|
|||
)
|
||||
};
|
||||
|
||||
(response_text, parsed_text, calls, assistant_history_content)
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
malformed_markup,
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
|
|
@ -658,8 +670,15 @@ pub(crate) async fn run_tool_call_loop(
|
|||
});
|
||||
let response_text = resp;
|
||||
let assistant_history_content = response_text.clone();
|
||||
let (parsed_text, calls) = parse_tool_calls(&response_text);
|
||||
(response_text, parsed_text, calls, assistant_history_content)
|
||||
let (parsed_text, calls, malformed_markup) =
|
||||
parse_tool_calls(&response_text);
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
malformed_markup,
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
|
|
@ -684,7 +703,8 @@ pub(crate) async fn run_tool_call_loop(
|
|||
};
|
||||
let has_tool_call_markup =
|
||||
response_text.contains("<tool_call>") && response_text.contains("</tool_call>");
|
||||
let malformed_tool_call_markup = looks_like_malformed_tool_call_markup(&response_text);
|
||||
let malformed_tool_call_markup =
|
||||
malformed_markup || looks_like_malformed_tool_call_markup(&response_text);
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// Recovery path: the model attempted tool use but emitted malformed JSON.
|
||||
|
|
@ -1427,10 +1447,11 @@ mod tests {
|
|||
{"name": "shell", "arguments": {"command": "ls -la"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert_eq!(text, "Let me check that.");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "shell");
|
||||
assert!(!malformed);
|
||||
assert_eq!(
|
||||
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||
"ls -la"
|
||||
|
|
@ -1446,18 +1467,20 @@ mod tests {
|
|||
{"name": "file_read", "arguments": {"path": "b.txt"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let (_, calls) = parse_tool_calls(response);
|
||||
let (_, calls, malformed) = parse_tool_calls(response);
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].name, "file_read");
|
||||
assert_eq!(calls[1].name, "file_read");
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_returns_text_only_when_no_calls() {
|
||||
let response = "Just a normal response with no tools.";
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert_eq!(text, "Just a normal response with no tools.");
|
||||
assert!(calls.is_empty());
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1467,9 +1490,23 @@ not valid json
|
|||
</tool_call>
|
||||
Some text after."#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(calls.is_empty());
|
||||
assert!(text.contains("Some text after."));
|
||||
assert!(malformed);
|
||||
}
|
||||
|
||||
// Regression test: explanatory prose followed by a <tool_call> block whose body is
// well-formed JSON but has no `name` key (it carries `action`/`command`/`expression`/`id`).
#[test]
|
||||
fn parse_tool_calls_marks_malformed_when_text_precedes_invalid_tool_call() {
|
||||
let response = r#"I will schedule a 3AM update task. First, I will inspect existing tasks:
|
||||
<tool_call>
|
||||
{"action":"create","command":"nova update","expression":"0 3 * * *","id":"nova-self-update"}
|
||||
</tool_call>"#;
|
||||
|
||||
// Expected: no call is extracted, the leading prose is still returned as text,
// and the third tuple element (the malformed-markup flag) is true.
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(calls.is_empty());
|
||||
assert!(text.contains("I will schedule a 3AM update task"));
|
||||
assert!(malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1480,10 +1517,11 @@ Some text after."#;
|
|||
</tool_call>
|
||||
After text."#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.contains("Before text."));
|
||||
assert!(text.contains("After text."));
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1491,7 +1529,7 @@ After text."#;
|
|||
// OpenAI-style response with tool_calls array
|
||||
let response = r#"{"content": "Let me check that for you.", "tool_calls": [{"type": "function", "function": {"name": "shell", "arguments": "{\"command\": \"ls -la\"}"}}]}"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert_eq!(text, "Let me check that for you.");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "shell");
|
||||
|
|
@ -1499,16 +1537,18 @@ After text."#;
|
|||
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||
"ls -la"
|
||||
);
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_handles_openai_format_multiple_calls() {
|
||||
let response = r#"{"tool_calls": [{"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"a.txt\"}"}}, {"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"b.txt\"}"}}]}"#;
|
||||
|
||||
let (_, calls) = parse_tool_calls(response);
|
||||
let (_, calls, malformed) = parse_tool_calls(response);
|
||||
assert_eq!(calls.len(), 2);
|
||||
assert_eq!(calls[0].name, "file_read");
|
||||
assert_eq!(calls[1].name, "file_read");
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1516,10 +1556,11 @@ After text."#;
|
|||
// Some providers don't include content field with tool_calls
|
||||
let response = r#"{"tool_calls": [{"type": "function", "function": {"name": "memory_recall", "arguments": "{}"}}]}"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.is_empty()); // No content field
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "memory_recall");
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1530,7 +1571,7 @@ After text."#;
|
|||
```
|
||||
</tool_call>"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.is_empty());
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "file_write");
|
||||
|
|
@ -1538,6 +1579,7 @@ After text."#;
|
|||
calls[0].arguments.get("path").unwrap().as_str().unwrap(),
|
||||
"test.py"
|
||||
);
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1547,7 +1589,7 @@ I will now call the tool with this payload:
|
|||
{"name": "shell", "arguments": {"command": "pwd"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.is_empty());
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "shell");
|
||||
|
|
@ -1555,6 +1597,7 @@ I will now call the tool with this payload:
|
|||
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||
"pwd"
|
||||
);
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1609,13 +1652,14 @@ I will now call the tool with this payload:
|
|||
let response = r#"Sure, creating the file now.
|
||||
{"name": "file_write", "arguments": {"path": "hello.py", "content": "print('hello')"}}"#;
|
||||
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.contains("Sure, creating the file now."));
|
||||
assert_eq!(
|
||||
calls.len(),
|
||||
0,
|
||||
"Raw JSON without wrappers should not be parsed"
|
||||
);
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1776,9 +1820,10 @@ I will now call the tool with this payload:
|
|||
|
||||
</tool_result>
|
||||
Done."#;
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
assert!(text.contains("Done."));
|
||||
assert!(calls.is_empty());
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1793,10 +1838,11 @@ Done."#;
|
|||
fn parse_tool_calls_handles_empty_tool_calls_array() {
|
||||
// Recovery: Empty tool_calls array returns original response (no tool parsing)
|
||||
let response = r#"{"content": "Hello", "tool_calls": []}"#;
|
||||
let (text, calls) = parse_tool_calls(response);
|
||||
let (text, calls, malformed) = parse_tool_calls(response);
|
||||
// When tool_calls is empty, the entire JSON is returned as text
|
||||
assert!(text.contains("Hello"));
|
||||
assert!(calls.is_empty());
|
||||
assert!(!malformed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -2086,6 +2132,86 @@ Done."#;
|
|||
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]")));
|
||||
}
|
||||
|
||||
// Scripted test provider. Its reply depends only on which markers are present in the
// user messages it is handed:
//   1. no marker             -> prose followed by a <tool_call> block whose JSON has no `name` key
//   2. "[Tool parser error]" -> a <tool_call> block naming `shell` with `echo fixed`
//   3. "[Tool results]"      -> the plain final answer "Scheduled successfully."
struct TextPrefixedMalformedThenValidToolProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for TextPrefixedMalformedThenValidToolProvider {
|
||||
// Always fails: the test using this provider is expected to go through
// `chat_with_history` only.
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
anyhow::bail!("chat_with_system should not be called in this test");
|
||||
}
|
||||
|
||||
// Returns the scripted reply for the current state of `messages`; model and
// temperature are ignored.
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
// Checked first: once any user message contains "[Tool results]", finish with a
// plain-text answer, even if a parser-error message is also present.
if messages
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool results]"))
|
||||
{
|
||||
return Ok("Scheduled successfully.".to_string());
|
||||
}
|
||||
|
||||
// A user message containing "[Tool parser error]" is present: answer with a
// corrected tool call that has both `name` and `arguments`.
if messages
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]"))
|
||||
{
|
||||
return Ok(r#"<tool_call>
|
||||
{"name":"shell","arguments":{"command":"echo fixed"}}
|
||||
</tool_call>"#
|
||||
.to_string());
|
||||
}
|
||||
|
||||
// Neither marker present: emit leading prose plus a <tool_call> block whose JSON
// object has no `name` key.
Ok(
|
||||
r#"I will schedule a 3AM update task. First, I will inspect existing tasks:
|
||||
<tool_call>
|
||||
{"action":"create","command":"nova update","expression":"0 3 * * *","id":"nova-self-update"}
|
||||
</tool_call>"#
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Drives `run_tool_call_loop` with `TextPrefixedMalformedThenValidToolProvider` and checks
// that a text-prefixed, invalid <tool_call> block leads to a "[Tool parser error]" user
// message in history and a successful final answer.
#[tokio::test]
|
||||
async fn run_tool_call_loop_retries_text_prefixed_invalid_tool_call_markup() {
|
||||
// Shared counter handed to the tool; read back after the loop finishes.
// NOTE(review): presumably CountingShellTool increments `runs` once per execution — confirm against its definition.
let runs = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingShellTool {
|
||||
runs: Arc::clone(&runs),
|
||||
})];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("sys"),
|
||||
ChatMessage::user("set schedule"),
|
||||
];
|
||||
|
||||
let response = run_tool_call_loop(
|
||||
&TextPrefixedMalformedThenValidToolProvider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&NoopObserver,
|
||||
"test-provider",
|
||||
"test-model",
|
||||
0.0,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// The final answer is the provider's plain-text reply, with no tool-call markup in it.
// The counter must read exactly 1, and history must contain a user message with
// "[Tool parser error]".
assert_eq!(response, "Scheduled successfully.");
|
||||
assert_eq!(runs.load(Ordering::SeqCst), 1);
|
||||
assert!(!response.contains("<tool_call>"));
|
||||
assert!(history
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]")));
|
||||
}
|
||||
|
||||
struct PrefixMalformedThenValidToolProvider;
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue