fix(agent): retry malformed tool_call payloads in tool loop
This commit is contained in:
parent
4b89e91a5a
commit
3522d51f98
1 changed file with 136 additions and 1 deletion
|
|
@@ -660,13 +660,26 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let display_text = if parsed_text.is_empty() {
|
let parsed_text_is_empty = parsed_text.trim().is_empty();
|
||||||
|
let display_text = if parsed_text_is_empty {
|
||||||
response_text.clone()
|
response_text.clone()
|
||||||
} else {
|
} else {
|
||||||
parsed_text
|
parsed_text
|
||||||
};
|
};
|
||||||
|
let has_tool_call_markup =
|
||||||
|
response_text.contains("<tool_call>") && response_text.contains("</tool_call>");
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
if tool_calls.is_empty() {
|
||||||
|
// Recovery path: the model attempted tool use but emitted malformed JSON.
|
||||||
|
// Ask it to re-send valid tool-call payload instead of leaking raw markup to users.
|
||||||
|
if has_tool_call_markup && parsed_text_is_empty {
|
||||||
|
history.push(ChatMessage::assistant(response_text.clone()));
|
||||||
|
history.push(ChatMessage::user(
|
||||||
|
"[Tool parser error]\nYour previous <tool_call> payload was invalid JSON and was NOT executed. Re-send the same tool call using strict valid JSON only. Escape inner double quotes inside string values.",
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// No tool calls — this is the final response
|
// No tool calls — this is the final response
|
||||||
history.push(ChatMessage::assistant(response_text.clone()));
|
history.push(ChatMessage::assistant(response_text.clone()));
|
||||||
return Ok(display_text);
|
return Ok(display_text);
|
||||||
|
|
@@ -1382,6 +1395,12 @@ mod tests {
|
||||||
assert!(scrubbed.contains("public"));
|
assert!(scrubbed.contains("public"));
|
||||||
}
|
}
|
||||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||||
|
use crate::observability::NoopObserver;
|
||||||
|
use crate::providers::Provider;
|
||||||
|
use crate::tools::{Tool, ToolResult};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@@ -1923,4 +1942,120 @@ Done."#;
|
||||||
let result = parse_tool_calls_from_json_value(&value);
|
let result = parse_tool_calls_from_json_value(&value);
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MalformedThenValidToolProvider;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MalformedThenValidToolProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
anyhow::bail!("chat_with_system should not be called in this test");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
if messages
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.role == "user" && m.content.contains("[Tool results]"))
|
||||||
|
{
|
||||||
|
return Ok("Top memory users parsed successfully.".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]"))
|
||||||
|
{
|
||||||
|
return Ok(
|
||||||
|
r#"<tool_call>
|
||||||
|
{"name":"shell","arguments":{"command":"echo fixed"}}
|
||||||
|
</tool_call>"#
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(
|
||||||
|
r#"<tool_call>
|
||||||
|
{"name":"shell","arguments":{"command":"echo "$rss $name ($pid)""}}
|
||||||
|
</tool_call>"#
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CountingShellTool {
|
||||||
|
runs: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CountingShellTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"shell"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Count shell executions"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
self.runs.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: args
|
||||||
|
.get("command")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_string(),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn run_tool_call_loop_retries_invalid_tool_call_markup() {
|
||||||
|
let runs = Arc::new(AtomicUsize::new(0));
|
||||||
|
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingShellTool {
|
||||||
|
runs: Arc::clone(&runs),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let mut history = vec![ChatMessage::system("sys"), ChatMessage::user("check memory")];
|
||||||
|
|
||||||
|
let response = run_tool_call_loop(
|
||||||
|
&MalformedThenValidToolProvider,
|
||||||
|
&mut history,
|
||||||
|
&tools_registry,
|
||||||
|
&NoopObserver,
|
||||||
|
"test-provider",
|
||||||
|
"test-model",
|
||||||
|
0.0,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(response, "Top memory users parsed successfully.");
|
||||||
|
assert_eq!(runs.load(Ordering::SeqCst), 1);
|
||||||
|
assert!(!response.contains("<tool_call>"));
|
||||||
|
assert!(history
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.role == "user" && m.content.contains("[Tool parser error]")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue