fix(agent): retry malformed tool_call payloads in tool loop
This commit is contained in:
parent
4b89e91a5a
commit
3522d51f98
1 changed files with 136 additions and 1 deletions
|
|
@ -660,13 +660,26 @@ pub(crate) async fn run_tool_call_loop(
|
|||
}
|
||||
};
|
||||
|
||||
let display_text = if parsed_text.is_empty() {
|
||||
let parsed_text_is_empty = parsed_text.trim().is_empty();
|
||||
let display_text = if parsed_text_is_empty {
|
||||
response_text.clone()
|
||||
} else {
|
||||
parsed_text
|
||||
};
|
||||
let has_tool_call_markup =
|
||||
response_text.contains("<tool_call>") && response_text.contains("</tool_call>");
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// Recovery path: the model attempted tool use but emitted malformed JSON.
|
||||
// Ask it to re-send valid tool-call payload instead of leaking raw markup to users.
|
||||
if has_tool_call_markup && parsed_text_is_empty {
|
||||
history.push(ChatMessage::assistant(response_text.clone()));
|
||||
history.push(ChatMessage::user(
|
||||
"[Tool parser error]\nYour previous <tool_call> payload was invalid JSON and was NOT executed. Re-send the same tool call using strict valid JSON only. Escape inner double quotes inside string values.",
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
// No tool calls — this is the final response
|
||||
history.push(ChatMessage::assistant(response_text.clone()));
|
||||
return Ok(display_text);
|
||||
|
|
@ -1382,6 +1395,12 @@ mod tests {
|
|||
assert!(scrubbed.contains("public"));
|
||||
}
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::Provider;
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
|
|
@ -1923,4 +1942,120 @@ Done."#;
|
|||
let result = parse_tool_calls_from_json_value(&value);
|
||||
assert_eq!(result.len(), 2);
|
||||
}
|
||||
|
||||
struct MalformedThenValidToolProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MalformedThenValidToolProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
anyhow::bail!("chat_with_system should not be called in this test");
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
if messages
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool results]"))
|
||||
{
|
||||
return Ok("Top memory users parsed successfully.".to_string());
|
||||
}
|
||||
|
||||
if messages
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]"))
|
||||
{
|
||||
return Ok(
|
||||
r#"<tool_call>
|
||||
{"name":"shell","arguments":{"command":"echo fixed"}}
|
||||
</tool_call>"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(
|
||||
r#"<tool_call>
|
||||
{"name":"shell","arguments":{"command":"echo "$rss $name ($pid)""}}
|
||||
</tool_call>"#
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct CountingShellTool {
|
||||
runs: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CountingShellTool {
|
||||
fn name(&self) -> &str {
|
||||
"shell"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Count shell executions"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
},
|
||||
"required": ["command"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
self.runs.fetch_add(1, Ordering::SeqCst);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: args
|
||||
.get("command")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_retries_invalid_tool_call_markup() {
|
||||
let runs = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingShellTool {
|
||||
runs: Arc::clone(&runs),
|
||||
})];
|
||||
|
||||
let mut history = vec![ChatMessage::system("sys"), ChatMessage::user("check memory")];
|
||||
|
||||
let response = run_tool_call_loop(
|
||||
&MalformedThenValidToolProvider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&NoopObserver,
|
||||
"test-provider",
|
||||
"test-model",
|
||||
0.0,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response, "Top memory users parsed successfully.");
|
||||
assert_eq!(runs.load(Ordering::SeqCst), 1);
|
||||
assert!(!response.contains("<tool_call>"));
|
||||
assert!(history
|
||||
.iter()
|
||||
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]")));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue