fix(agent): parse tool-call alias tags in channel runtime
This commit is contained in:
parent
c6d068a371
commit
4243d8ec86
4 changed files with 133 additions and 6 deletions
|
|
@ -329,6 +329,15 @@ fn parse_tool_calls_from_json_value(value: &serde_json::Value) -> Vec<ParsedTool
|
||||||
calls
|
calls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const TOOL_CALL_OPEN_TAGS: [&str; 3] = ["<tool_call>", "<toolcall>", "<tool-call>"];
|
||||||
|
const TOOL_CALL_CLOSE_TAGS: [&str; 3] = ["</tool_call>", "</toolcall>", "</tool-call>"];
|
||||||
|
|
||||||
|
fn find_first_tag<'a>(haystack: &str, tags: &'a [&'a str]) -> Option<(usize, &'a str)> {
|
||||||
|
tags.iter()
|
||||||
|
.filter_map(|tag| haystack.find(tag).map(|idx| (idx, *tag)))
|
||||||
|
.min_by_key(|(idx, _)| *idx)
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract JSON values from a string.
|
/// Extract JSON values from a string.
|
||||||
///
|
///
|
||||||
/// # Security Warning
|
/// # Security Warning
|
||||||
|
|
@ -385,6 +394,9 @@ fn extract_json_values(input: &str) -> Vec<serde_json::Value> {
|
||||||
/// </tool_call>
|
/// </tool_call>
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
|
/// Also accepts common tag variants (`<toolcall>`, `<tool-call>`) for model
|
||||||
|
/// compatibility.
|
||||||
|
///
|
||||||
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
|
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
|
||||||
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
let mut text_parts = Vec::new();
|
let mut text_parts = Vec::new();
|
||||||
|
|
@ -406,16 +418,17 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to XML-style <invoke> tag parsing (ZeroClaw's original format)
|
// Fall back to XML-style tool-call tag parsing.
|
||||||
while let Some(start) = remaining.find("<tool_call>") {
|
while let Some((start, open_tag)) = find_first_tag(remaining, &TOOL_CALL_OPEN_TAGS) {
|
||||||
// Everything before the tag is text
|
// Everything before the tag is text
|
||||||
let before = &remaining[..start];
|
let before = &remaining[..start];
|
||||||
if !before.trim().is_empty() {
|
if !before.trim().is_empty() {
|
||||||
text_parts.push(before.trim().to_string());
|
text_parts.push(before.trim().to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(end) = remaining[start..].find("</tool_call>") {
|
let after_open = &remaining[start + open_tag.len()..];
|
||||||
let inner = &remaining[start + 11..start + end];
|
if let Some((close_idx, close_tag)) = find_first_tag(after_open, &TOOL_CALL_CLOSE_TAGS) {
|
||||||
|
let inner = &after_open[..close_idx];
|
||||||
let mut parsed_any = false;
|
let mut parsed_any = false;
|
||||||
let json_values = extract_json_values(inner);
|
let json_values = extract_json_values(inner);
|
||||||
for value in json_values {
|
for value in json_values {
|
||||||
|
|
@ -430,7 +443,7 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
tracing::warn!("Malformed <tool_call> JSON: expected tool-call object in tag body");
|
tracing::warn!("Malformed <tool_call> JSON: expected tool-call object in tag body");
|
||||||
}
|
}
|
||||||
|
|
||||||
remaining = &remaining[start + end + 12..];
|
remaining = &after_open[close_idx + close_tag.len()..];
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -1496,6 +1509,38 @@ I will now call the tool with this payload:
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_handles_toolcall_tag_alias() {
|
||||||
|
let response = r#"<toolcall>
|
||||||
|
{"name": "shell", "arguments": {"command": "date"}}
|
||||||
|
</toolcall>"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.is_empty());
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"date"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_handles_tool_dash_call_tag_alias() {
|
||||||
|
let response = r#"<tool-call>
|
||||||
|
{"name": "shell", "arguments": {"command": "whoami"}}
|
||||||
|
</tool-call>"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.is_empty());
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"whoami"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_tool_calls_rejects_raw_tool_json_without_tags() {
|
fn parse_tool_calls_rejects_raw_tool_json_without_tags() {
|
||||||
// SECURITY: Raw JSON without explicit wrappers should NOT be parsed
|
// SECURITY: Raw JSON without explicit wrappers should NOT be parsed
|
||||||
|
|
|
||||||
|
|
@ -1370,6 +1370,13 @@ mod tests {
|
||||||
.to_string()
|
.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tool_call_payload_with_alias_tag() -> String {
|
||||||
|
r#"<toolcall>
|
||||||
|
{"name":"mock_price","arguments":{"symbol":"BTC"}}
|
||||||
|
</toolcall>"#
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl Provider for ToolCallingProvider {
|
impl Provider for ToolCallingProvider {
|
||||||
async fn chat_with_system(
|
async fn chat_with_system(
|
||||||
|
|
@ -1399,6 +1406,37 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ToolCallingAliasProvider;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Provider for ToolCallingAliasProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok(tool_call_payload_with_alias_tag())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let has_tool_results = messages
|
||||||
|
.iter()
|
||||||
|
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||||||
|
if has_tool_results {
|
||||||
|
Ok("BTC alias-tag flow resolved to final text output.".to_string())
|
||||||
|
} else {
|
||||||
|
Ok(tool_call_payload_with_alias_tag())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct MockPriceTool;
|
struct MockPriceTool;
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|
@ -1480,6 +1518,47 @@ mod tests {
|
||||||
assert!(!sent_messages[0].contains("mock_price"));
|
assert!(!sent_messages[0].contains("mock_price"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_channel_message_executes_tool_calls_with_alias_tags() {
|
||||||
|
let channel_impl = Arc::new(RecordingChannel::default());
|
||||||
|
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||||
|
|
||||||
|
let mut channels_by_name = HashMap::new();
|
||||||
|
channels_by_name.insert(channel.name().to_string(), channel);
|
||||||
|
|
||||||
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
|
channels_by_name: Arc::new(channels_by_name),
|
||||||
|
provider: Arc::new(ToolCallingAliasProvider),
|
||||||
|
memory: Arc::new(NoopMemory),
|
||||||
|
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||||
|
observer: Arc::new(NoopObserver),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||||
|
model: Arc::new("test-model".to_string()),
|
||||||
|
temperature: 0.0,
|
||||||
|
auto_save_memory: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
process_channel_message(
|
||||||
|
runtime_ctx,
|
||||||
|
traits::ChannelMessage {
|
||||||
|
id: "msg-2".to_string(),
|
||||||
|
sender: "bob".to_string(),
|
||||||
|
reply_target: "chat-84".to_string(),
|
||||||
|
content: "What is the BTC price now?".to_string(),
|
||||||
|
channel: "test-channel".to_string(),
|
||||||
|
timestamp: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||||
|
assert_eq!(sent_messages.len(), 1);
|
||||||
|
assert!(sent_messages[0].starts_with("chat-84:"));
|
||||||
|
assert!(sent_messages[0].contains("alias-tag flow resolved"));
|
||||||
|
assert!(!sent_messages[0].contains("<toolcall>"));
|
||||||
|
assert!(!sent_messages[0].contains("mock_price"));
|
||||||
|
}
|
||||||
|
|
||||||
struct NoopMemory;
|
struct NoopMemory;
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|
|
||||||
|
|
@ -96,6 +96,7 @@ pub fn config_from_wizard_choice(choice: usize, devices: &[DiscoveredDevice]) ->
|
||||||
pub fn handle_command(cmd: crate::HardwareCommands, _config: &Config) -> Result<()> {
|
pub fn handle_command(cmd: crate::HardwareCommands, _config: &Config) -> Result<()> {
|
||||||
#[cfg(not(feature = "hardware"))]
|
#[cfg(not(feature = "hardware"))]
|
||||||
{
|
{
|
||||||
|
let _ = &cmd;
|
||||||
println!("Hardware discovery requires the 'hardware' feature.");
|
println!("Hardware discovery requires the 'hardware' feature.");
|
||||||
println!("Build with: cargo build --features hardware");
|
println!("Build with: cargo build --features hardware");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,9 @@ pub mod rpi;
|
||||||
pub use traits::Peripheral;
|
pub use traits::Peripheral;
|
||||||
|
|
||||||
use crate::config::{Config, PeripheralBoardConfig, PeripheralsConfig};
|
use crate::config::{Config, PeripheralBoardConfig, PeripheralsConfig};
|
||||||
use crate::tools::{HardwareMemoryMapTool, Tool};
|
#[cfg(feature = "hardware")]
|
||||||
|
use crate::tools::HardwareMemoryMapTool;
|
||||||
|
use crate::tools::Tool;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
/// List configured boards from config (no connection yet).
|
/// List configured boards from config (no connection yet).
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue