fix(agent): retry malformed tool_call payloads in tool loop

This commit is contained in:
JamesYin 2026-02-17 21:51:00 +08:00 committed by Chummy
parent 4b89e91a5a
commit 3522d51f98

View file

@ -660,13 +660,26 @@ pub(crate) async fn run_tool_call_loop(
} }
}; };
let display_text = if parsed_text.is_empty() { let parsed_text_is_empty = parsed_text.trim().is_empty();
let display_text = if parsed_text_is_empty {
response_text.clone() response_text.clone()
} else { } else {
parsed_text parsed_text
}; };
let has_tool_call_markup =
response_text.contains("<tool_call>") && response_text.contains("</tool_call>");
if tool_calls.is_empty() { if tool_calls.is_empty() {
// Recovery path: the model attempted tool use but emitted malformed JSON.
// Ask it to re-send valid tool-call payload instead of leaking raw markup to users.
if has_tool_call_markup && parsed_text_is_empty {
history.push(ChatMessage::assistant(response_text.clone()));
history.push(ChatMessage::user(
"[Tool parser error]\nYour previous <tool_call> payload was invalid JSON and was NOT executed. Re-send the same tool call using strict valid JSON only. Escape inner double quotes inside string values.",
));
continue;
}
// No tool calls — this is the final response // No tool calls — this is the final response
history.push(ChatMessage::assistant(response_text.clone())); history.push(ChatMessage::assistant(response_text.clone()));
return Ok(display_text); return Ok(display_text);
@ -1382,6 +1395,12 @@ mod tests {
assert!(scrubbed.contains("public")); assert!(scrubbed.contains("public"));
} }
use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::observability::NoopObserver;
use crate::providers::Provider;
use crate::tools::{Tool, ToolResult};
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tempfile::TempDir; use tempfile::TempDir;
#[test] #[test]
@ -1923,4 +1942,120 @@ Done."#;
let result = parse_tool_calls_from_json_value(&value); let result = parse_tool_calls_from_json_value(&value);
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
} }
struct MalformedThenValidToolProvider;
#[async_trait]
impl Provider for MalformedThenValidToolProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
anyhow::bail!("chat_with_system should not be called in this test");
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
if messages
.iter()
.any(|m| m.role == "user" && m.content.contains("[Tool results]"))
{
return Ok("Top memory users parsed successfully.".to_string());
}
if messages
.iter()
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]"))
{
return Ok(
r#"<tool_call>
{"name":"shell","arguments":{"command":"echo fixed"}}
</tool_call>"#
.to_string(),
);
}
Ok(
r#"<tool_call>
{"name":"shell","arguments":{"command":"echo "$rss $name ($pid)""}}
</tool_call>"#
.to_string(),
)
}
}
struct CountingShellTool {
runs: Arc<AtomicUsize>,
}
#[async_trait]
impl Tool for CountingShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> &str {
"Count shell executions"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"command": { "type": "string" }
},
"required": ["command"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
self.runs.fetch_add(1, Ordering::SeqCst);
Ok(ToolResult {
success: true,
output: args
.get("command")
.and_then(serde_json::Value::as_str)
.unwrap_or_default()
.to_string(),
error: None,
})
}
}
#[tokio::test]
async fn run_tool_call_loop_retries_invalid_tool_call_markup() {
let runs = Arc::new(AtomicUsize::new(0));
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingShellTool {
runs: Arc::clone(&runs),
})];
let mut history = vec![ChatMessage::system("sys"), ChatMessage::user("check memory")];
let response = run_tool_call_loop(
&MalformedThenValidToolProvider,
&mut history,
&tools_registry,
&NoopObserver,
"test-provider",
"test-model",
0.0,
true,
)
.await
.unwrap();
assert_eq!(response, "Top memory users parsed successfully.");
assert_eq!(runs.load(Ordering::SeqCst), 1);
assert!(!response.contains("<tool_call>"));
assert!(history
.iter()
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]")));
}
} }