* feat(memory): add session_id isolation to Memory trait Add optional session_id parameter to store(), recall(), and list() methods across the Memory trait and all four backends (sqlite, markdown, lucid, none). This enables per-session memory isolation so different agent sessions cannot cross-read each other's stored memories. Changes: - traits.rs: Add session_id: Option<&str> to store/recall/list - sqlite.rs: Schema migration (ALTER TABLE ADD COLUMN session_id), index, persist/filter by session_id in all query paths - markdown.rs, lucid.rs, none.rs: Updated signatures - All callers pass None for backward compatibility - 5 new tests: session-filtered recall, cross-session isolation, session-filtered list, no-filter returns all, migration idempotency Closes #518 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(channels): fix discord _channel_id typo and lark missing reply_to Pre-existing compilation errors on main after reply_to was added to ChannelMessage: discord.rs used _channel_id (underscore prefix) but referenced channel_id, and lark.rs was missing the reply_to field. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1721 lines
64 KiB
Rust
1721 lines
64 KiB
Rust
use crate::config::Config;
|
|
use crate::memory::{self, Memory, MemoryCategory};
|
|
use crate::observability::{self, Observer, ObserverEvent};
|
|
use crate::providers::{self, ChatMessage, Provider, ToolCall};
|
|
use crate::runtime;
|
|
use crate::security::SecurityPolicy;
|
|
use crate::tools::{self, Tool};
|
|
use crate::util::truncate_with_ellipsis;
|
|
use anyhow::Result;
|
|
use std::fmt::Write;
|
|
use std::io::Write as _;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use uuid::Uuid;
|
|
/// Upper bound on LLM round-trips for a single user message.
///
/// `run_tool_call_loop` gives up with an error once this many iterations pass without the
/// model producing a final (tool-free) answer, which stops runaway tool loops.
const MAX_TOOL_ITERATIONS: usize = 10;

/// Non-system message count above which history is trimmed or compacted.
const MAX_HISTORY_MESSAGES: usize = 50;

/// How many of the newest non-system messages survive compaction verbatim.
const COMPACTION_KEEP_RECENT_MESSAGES: usize = 20;

/// Cap, in `char`s, on the transcript handed to the compaction summarizer.
const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;

/// Cap, in `char`s, on the summary text kept in history after compaction.
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
|
|
|
|
/// Convert a tool registry to OpenAI function-calling format for native tool support.
|
|
fn tools_to_openai_format(tools_registry: &[Box<dyn Tool>]) -> Vec<serde_json::Value> {
|
|
tools_registry
|
|
.iter()
|
|
.map(|tool| {
|
|
serde_json::json!({
|
|
"type": "function",
|
|
"function": {
|
|
"name": tool.name(),
|
|
"description": tool.description(),
|
|
"parameters": tool.parameters_schema()
|
|
}
|
|
})
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn autosave_memory_key(prefix: &str) -> String {
|
|
format!("{prefix}_{}", Uuid::new_v4())
|
|
}
|
|
|
|
/// Trim conversation history to prevent unbounded growth.
|
|
/// Preserves the system prompt (first message if role=system) and the most recent messages.
|
|
fn trim_history(history: &mut Vec<ChatMessage>) {
|
|
// Nothing to trim if within limit
|
|
let has_system = history.first().map_or(false, |m| m.role == "system");
|
|
let non_system_count = if has_system {
|
|
history.len() - 1
|
|
} else {
|
|
history.len()
|
|
};
|
|
|
|
if non_system_count <= MAX_HISTORY_MESSAGES {
|
|
return;
|
|
}
|
|
|
|
let start = if has_system { 1 } else { 0 };
|
|
let to_remove = non_system_count - MAX_HISTORY_MESSAGES;
|
|
history.drain(start..start + to_remove);
|
|
}
|
|
|
|
fn build_compaction_transcript(messages: &[ChatMessage]) -> String {
|
|
let mut transcript = String::new();
|
|
for msg in messages {
|
|
let role = msg.role.to_uppercase();
|
|
let _ = writeln!(transcript, "{role}: {}", msg.content.trim());
|
|
}
|
|
|
|
if transcript.chars().count() > COMPACTION_MAX_SOURCE_CHARS {
|
|
truncate_with_ellipsis(&transcript, COMPACTION_MAX_SOURCE_CHARS)
|
|
} else {
|
|
transcript
|
|
}
|
|
}
|
|
|
|
fn apply_compaction_summary(
|
|
history: &mut Vec<ChatMessage>,
|
|
start: usize,
|
|
compact_end: usize,
|
|
summary: &str,
|
|
) {
|
|
let summary_msg = ChatMessage::assistant(format!("[Compaction summary]\n{}", summary.trim()));
|
|
history.splice(start..compact_end, std::iter::once(summary_msg));
|
|
}
|
|
|
|
async fn auto_compact_history(
|
|
history: &mut Vec<ChatMessage>,
|
|
provider: &dyn Provider,
|
|
model: &str,
|
|
) -> Result<bool> {
|
|
let has_system = history.first().map_or(false, |m| m.role == "system");
|
|
let non_system_count = if has_system {
|
|
history.len().saturating_sub(1)
|
|
} else {
|
|
history.len()
|
|
};
|
|
|
|
if non_system_count <= MAX_HISTORY_MESSAGES {
|
|
return Ok(false);
|
|
}
|
|
|
|
let start = if has_system { 1 } else { 0 };
|
|
let keep_recent = COMPACTION_KEEP_RECENT_MESSAGES.min(non_system_count);
|
|
let compact_count = non_system_count.saturating_sub(keep_recent);
|
|
if compact_count == 0 {
|
|
return Ok(false);
|
|
}
|
|
|
|
let compact_end = start + compact_count;
|
|
let to_compact: Vec<ChatMessage> = history[start..compact_end].to_vec();
|
|
let transcript = build_compaction_transcript(&to_compact);
|
|
|
|
let summarizer_system = "You are a conversation compaction engine. Summarize older chat history into concise context for future turns. Preserve: user preferences, commitments, decisions, unresolved tasks, key facts. Omit: filler, repeated chit-chat, verbose tool logs. Output plain text bullet points only.";
|
|
|
|
let summarizer_user = format!(
|
|
"Summarize the following conversation history for context preservation. Keep it short (max 12 bullet points).\n\n{}",
|
|
transcript
|
|
);
|
|
|
|
let summary_raw = provider
|
|
.chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2)
|
|
.await
|
|
.unwrap_or_else(|_| {
|
|
// Fallback to deterministic local truncation when summarization fails.
|
|
truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS)
|
|
});
|
|
|
|
let summary = truncate_with_ellipsis(&summary_raw, COMPACTION_MAX_SUMMARY_CHARS);
|
|
apply_compaction_summary(history, start, compact_end, &summary);
|
|
|
|
Ok(true)
|
|
}
|
|
|
|
/// Build context preamble by searching memory for relevant entries
|
|
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
|
let mut context = String::new();
|
|
|
|
// Pull relevant memories for this message
|
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
|
if !entries.is_empty() {
|
|
context.push_str("[Memory context]\n");
|
|
for entry in &entries {
|
|
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
|
}
|
|
context.push('\n');
|
|
}
|
|
}
|
|
|
|
context
|
|
}
|
|
|
|
/// Build hardware datasheet context from RAG when peripherals are enabled.
|
|
/// Includes pin-alias lookup (e.g. "red_led" → 13) when query matches, plus retrieved chunks.
|
|
fn build_hardware_context(
|
|
rag: &crate::rag::HardwareRag,
|
|
user_msg: &str,
|
|
boards: &[String],
|
|
chunk_limit: usize,
|
|
) -> String {
|
|
if rag.is_empty() || boards.is_empty() {
|
|
return String::new();
|
|
}
|
|
|
|
let mut context = String::new();
|
|
|
|
// Pin aliases: when user says "red led", inject "red_led: 13" for matching boards
|
|
let pin_ctx = rag.pin_alias_context(user_msg, boards);
|
|
if !pin_ctx.is_empty() {
|
|
context.push_str(&pin_ctx);
|
|
}
|
|
|
|
let chunks = rag.retrieve(user_msg, boards, chunk_limit);
|
|
if chunks.is_empty() && pin_ctx.is_empty() {
|
|
return String::new();
|
|
}
|
|
|
|
if !chunks.is_empty() {
|
|
context.push_str("[Hardware documentation]\n");
|
|
}
|
|
for chunk in chunks {
|
|
let board_tag = chunk.board.as_deref().unwrap_or("generic");
|
|
let _ = writeln!(
|
|
context,
|
|
"--- {} ({}) ---\n{}\n",
|
|
chunk.source, board_tag, chunk.content
|
|
);
|
|
}
|
|
context.push('\n');
|
|
context
|
|
}
|
|
|
|
/// Find a tool by name in the registry.
|
|
fn find_tool<'a>(tools: &'a [Box<dyn Tool>], name: &str) -> Option<&'a dyn Tool> {
|
|
tools.iter().find(|t| t.name() == name).map(|t| t.as_ref())
|
|
}
|
|
|
|
/// Normalize a tool call's `arguments` field into a JSON value.
///
/// - Missing arguments become an empty object.
/// - A JSON string is treated as encoded JSON and decoded. If decoding fails, the result is
///   an empty object.
/// - Any other value is cloned unchanged.
fn parse_arguments_value(raw: Option<&serde_json::Value>) -> serde_json::Value {
    let empty_object = || serde_json::Value::Object(serde_json::Map::new());
    match raw {
        None => empty_object(),
        Some(serde_json::Value::String(encoded)) => {
            serde_json::from_str(encoded).unwrap_or_else(|_| empty_object())
        }
        Some(other) => other.clone(),
    }
}
|
|
|
|
fn parse_tool_call_value(value: &serde_json::Value) -> Option<ParsedToolCall> {
|
|
if let Some(function) = value.get("function") {
|
|
let name = function
|
|
.get("name")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("")
|
|
.trim()
|
|
.to_string();
|
|
if !name.is_empty() {
|
|
let arguments = parse_arguments_value(function.get("arguments"));
|
|
return Some(ParsedToolCall { name, arguments });
|
|
}
|
|
}
|
|
|
|
let name = value
|
|
.get("name")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("")
|
|
.trim()
|
|
.to_string();
|
|
|
|
if name.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let arguments = parse_arguments_value(value.get("arguments"));
|
|
Some(ParsedToolCall { name, arguments })
|
|
}
|
|
|
|
fn parse_tool_calls_from_json_value(value: &serde_json::Value) -> Vec<ParsedToolCall> {
|
|
let mut calls = Vec::new();
|
|
|
|
if let Some(tool_calls) = value.get("tool_calls").and_then(|v| v.as_array()) {
|
|
for call in tool_calls {
|
|
if let Some(parsed) = parse_tool_call_value(call) {
|
|
calls.push(parsed);
|
|
}
|
|
}
|
|
|
|
if !calls.is_empty() {
|
|
return calls;
|
|
}
|
|
}
|
|
|
|
if let Some(array) = value.as_array() {
|
|
for item in array {
|
|
if let Some(parsed) = parse_tool_call_value(item) {
|
|
calls.push(parsed);
|
|
}
|
|
}
|
|
return calls;
|
|
}
|
|
|
|
if let Some(parsed) = parse_tool_call_value(value) {
|
|
calls.push(parsed);
|
|
}
|
|
|
|
calls
|
|
}
|
|
|
|
/// Extract JSON values from a string.
|
|
///
|
|
/// # Security Warning
|
|
///
|
|
/// This function extracts ANY JSON objects/arrays from the input. It MUST only
|
|
/// be used on content that is already trusted to be from the LLM, such as
|
|
/// content inside `<invoke>` tags where the LLM has explicitly indicated intent
|
|
/// to make a tool call. Do NOT use this on raw user input or content that
|
|
/// could contain prompt injection payloads.
|
|
fn extract_json_values(input: &str) -> Vec<serde_json::Value> {
|
|
let mut values = Vec::new();
|
|
let trimmed = input.trim();
|
|
if trimmed.is_empty() {
|
|
return values;
|
|
}
|
|
|
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
|
|
values.push(value);
|
|
return values;
|
|
}
|
|
|
|
let char_positions: Vec<(usize, char)> = trimmed.char_indices().collect();
|
|
let mut idx = 0;
|
|
while idx < char_positions.len() {
|
|
let (byte_idx, ch) = char_positions[idx];
|
|
if ch == '{' || ch == '[' {
|
|
let slice = &trimmed[byte_idx..];
|
|
let mut stream =
|
|
serde_json::Deserializer::from_str(slice).into_iter::<serde_json::Value>();
|
|
if let Some(Ok(value)) = stream.next() {
|
|
let consumed = stream.byte_offset();
|
|
if consumed > 0 {
|
|
values.push(value);
|
|
let next_byte = byte_idx + consumed;
|
|
while idx < char_positions.len() && char_positions[idx].0 < next_byte {
|
|
idx += 1;
|
|
}
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
idx += 1;
|
|
}
|
|
|
|
values
|
|
}
|
|
|
|
/// Parse tool calls from an LLM response that uses XML-style function calling.
|
|
///
|
|
/// Expected format (common with system-prompt-guided tool use):
|
|
/// ```text
|
|
/// <tool_call>
|
|
/// {"name": "shell", "arguments": {"command": "ls"}}
|
|
/// </tool_call>
|
|
/// ```
|
|
///
|
|
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
|
|
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
|
let mut text_parts = Vec::new();
|
|
let mut calls = Vec::new();
|
|
let mut remaining = response;
|
|
|
|
// First, try to parse as OpenAI-style JSON response with tool_calls array
|
|
// This handles providers like Minimax that return tool_calls in native JSON format
|
|
if let Ok(json_value) = serde_json::from_str::<serde_json::Value>(response.trim()) {
|
|
calls = parse_tool_calls_from_json_value(&json_value);
|
|
if !calls.is_empty() {
|
|
// If we found tool_calls, extract any content field as text
|
|
if let Some(content) = json_value.get("content").and_then(|v| v.as_str()) {
|
|
if !content.trim().is_empty() {
|
|
text_parts.push(content.trim().to_string());
|
|
}
|
|
}
|
|
return (text_parts.join("\n"), calls);
|
|
}
|
|
}
|
|
|
|
// Fall back to XML-style <invoke> tag parsing (ZeroClaw's original format)
|
|
while let Some(start) = remaining.find("<tool_call>") {
|
|
// Everything before the tag is text
|
|
let before = &remaining[..start];
|
|
if !before.trim().is_empty() {
|
|
text_parts.push(before.trim().to_string());
|
|
}
|
|
|
|
if let Some(end) = remaining[start..].find("</tool_call>") {
|
|
let inner = &remaining[start + 11..start + end];
|
|
let mut parsed_any = false;
|
|
let json_values = extract_json_values(inner);
|
|
for value in json_values {
|
|
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
|
if !parsed_calls.is_empty() {
|
|
parsed_any = true;
|
|
calls.extend(parsed_calls);
|
|
}
|
|
}
|
|
|
|
if !parsed_any {
|
|
tracing::warn!("Malformed <tool_call> JSON: expected tool-call object in tag body");
|
|
}
|
|
|
|
remaining = &remaining[start + end + 12..];
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// SECURITY: We do NOT fall back to extracting arbitrary JSON from the response
|
|
// here. That would enable prompt injection attacks where malicious content
|
|
// (e.g., in emails, files, or web pages) could include JSON that mimics a
|
|
// tool call. Tool calls MUST be explicitly wrapped in either:
|
|
// 1. OpenAI-style JSON with a "tool_calls" array
|
|
// 2. ZeroClaw <invoke>...</invoke> tags
|
|
// This ensures only the LLM's intentional tool calls are executed.
|
|
|
|
// Remaining text after last tool call
|
|
if !remaining.trim().is_empty() {
|
|
text_parts.push(remaining.trim().to_string());
|
|
}
|
|
|
|
(text_parts.join("\n"), calls)
|
|
}
|
|
|
|
fn parse_structured_tool_calls(tool_calls: &[ToolCall]) -> Vec<ParsedToolCall> {
|
|
tool_calls
|
|
.iter()
|
|
.map(|call| ParsedToolCall {
|
|
name: call.name.clone(),
|
|
arguments: serde_json::from_str::<serde_json::Value>(&call.arguments)
|
|
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())),
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn build_assistant_history_with_tool_calls(text: &str, tool_calls: &[ToolCall]) -> String {
|
|
let mut parts = Vec::new();
|
|
|
|
if !text.trim().is_empty() {
|
|
parts.push(text.trim().to_string());
|
|
}
|
|
|
|
for call in tool_calls {
|
|
let arguments = serde_json::from_str::<serde_json::Value>(&call.arguments)
|
|
.unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone()));
|
|
let payload = serde_json::json!({
|
|
"id": call.id,
|
|
"name": call.name,
|
|
"arguments": arguments,
|
|
});
|
|
parts.push(format!("<tool_call>\n{payload}\n</tool_call>"));
|
|
}
|
|
|
|
parts.join("\n")
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct ParsedToolCall {
|
|
name: String,
|
|
arguments: serde_json::Value,
|
|
}
|
|
|
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
|
/// execute tools, and loop until the LLM produces a final text response.
|
|
/// When `silent` is true, suppresses stdout (for channel use).
|
|
pub(crate) async fn agent_turn(
|
|
provider: &dyn Provider,
|
|
history: &mut Vec<ChatMessage>,
|
|
tools_registry: &[Box<dyn Tool>],
|
|
observer: &dyn Observer,
|
|
provider_name: &str,
|
|
model: &str,
|
|
temperature: f64,
|
|
silent: bool,
|
|
) -> Result<String> {
|
|
run_tool_call_loop(
|
|
provider,
|
|
history,
|
|
tools_registry,
|
|
observer,
|
|
provider_name,
|
|
model,
|
|
temperature,
|
|
silent,
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
|
/// execute tools, and loop until the LLM produces a final text response.
|
|
pub(crate) async fn run_tool_call_loop(
|
|
provider: &dyn Provider,
|
|
history: &mut Vec<ChatMessage>,
|
|
tools_registry: &[Box<dyn Tool>],
|
|
observer: &dyn Observer,
|
|
provider_name: &str,
|
|
model: &str,
|
|
temperature: f64,
|
|
silent: bool,
|
|
) -> Result<String> {
|
|
// Build native tool definitions once if the provider supports them.
|
|
let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty();
|
|
let tool_definitions = if use_native_tools {
|
|
tools_to_openai_format(tools_registry)
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
for _iteration in 0..MAX_TOOL_ITERATIONS {
|
|
observer.record_event(&ObserverEvent::LlmRequest {
|
|
provider: provider_name.to_string(),
|
|
model: model.to_string(),
|
|
messages_count: history.len(),
|
|
});
|
|
|
|
let llm_started_at = Instant::now();
|
|
|
|
// Choose between native tool-call API and prompt-based tool use.
|
|
let (response_text, parsed_text, tool_calls, assistant_history_content) =
|
|
if use_native_tools {
|
|
match provider
|
|
.chat_with_tools(history, &tool_definitions, model, temperature)
|
|
.await
|
|
{
|
|
Ok(resp) => {
|
|
observer.record_event(&ObserverEvent::LlmResponse {
|
|
provider: provider_name.to_string(),
|
|
model: model.to_string(),
|
|
duration: llm_started_at.elapsed(),
|
|
success: true,
|
|
error_message: None,
|
|
});
|
|
let response_text = resp.text_or_empty().to_string();
|
|
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
|
let mut parsed_text = String::new();
|
|
|
|
if calls.is_empty() {
|
|
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
|
if !fallback_text.is_empty() {
|
|
parsed_text = fallback_text;
|
|
}
|
|
calls = fallback_calls;
|
|
}
|
|
|
|
let assistant_history_content = if resp.tool_calls.is_empty() {
|
|
response_text.clone()
|
|
} else {
|
|
build_assistant_history_with_tool_calls(
|
|
&response_text,
|
|
&resp.tool_calls,
|
|
)
|
|
};
|
|
|
|
(response_text, parsed_text, calls, assistant_history_content)
|
|
}
|
|
Err(e) => {
|
|
observer.record_event(&ObserverEvent::LlmResponse {
|
|
provider: provider_name.to_string(),
|
|
model: model.to_string(),
|
|
duration: llm_started_at.elapsed(),
|
|
success: false,
|
|
error_message: Some(crate::providers::sanitize_api_error(
|
|
&e.to_string(),
|
|
)),
|
|
});
|
|
return Err(e);
|
|
}
|
|
}
|
|
} else {
|
|
match provider
|
|
.chat_with_history(history, model, temperature)
|
|
.await
|
|
{
|
|
Ok(resp) => {
|
|
observer.record_event(&ObserverEvent::LlmResponse {
|
|
provider: provider_name.to_string(),
|
|
model: model.to_string(),
|
|
duration: llm_started_at.elapsed(),
|
|
success: true,
|
|
error_message: None,
|
|
});
|
|
let response_text = resp;
|
|
let assistant_history_content = response_text.clone();
|
|
let (parsed_text, calls) = parse_tool_calls(&response_text);
|
|
(response_text, parsed_text, calls, assistant_history_content)
|
|
}
|
|
Err(e) => {
|
|
observer.record_event(&ObserverEvent::LlmResponse {
|
|
provider: provider_name.to_string(),
|
|
model: model.to_string(),
|
|
duration: llm_started_at.elapsed(),
|
|
success: false,
|
|
error_message: Some(crate::providers::sanitize_api_error(
|
|
&e.to_string(),
|
|
)),
|
|
});
|
|
return Err(e);
|
|
}
|
|
}
|
|
};
|
|
|
|
let display_text = if parsed_text.is_empty() {
|
|
response_text.clone()
|
|
} else {
|
|
parsed_text
|
|
};
|
|
|
|
if tool_calls.is_empty() {
|
|
// No tool calls — this is the final response
|
|
history.push(ChatMessage::assistant(response_text.clone()));
|
|
return Ok(display_text);
|
|
}
|
|
|
|
// Print any text the LLM produced alongside tool calls (unless silent)
|
|
if !silent && !display_text.is_empty() {
|
|
print!("{display_text}");
|
|
let _ = std::io::stdout().flush();
|
|
}
|
|
|
|
// Execute each tool call and build results
|
|
let mut tool_results = String::new();
|
|
for call in &tool_calls {
|
|
observer.record_event(&ObserverEvent::ToolCallStart {
|
|
tool: call.name.clone(),
|
|
});
|
|
let start = Instant::now();
|
|
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
|
|
match tool.execute(call.arguments.clone()).await {
|
|
Ok(r) => {
|
|
observer.record_event(&ObserverEvent::ToolCall {
|
|
tool: call.name.clone(),
|
|
duration: start.elapsed(),
|
|
success: r.success,
|
|
});
|
|
if r.success {
|
|
r.output
|
|
} else {
|
|
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
|
}
|
|
}
|
|
Err(e) => {
|
|
observer.record_event(&ObserverEvent::ToolCall {
|
|
tool: call.name.clone(),
|
|
duration: start.elapsed(),
|
|
success: false,
|
|
});
|
|
format!("Error executing {}: {e}", call.name)
|
|
}
|
|
}
|
|
} else {
|
|
format!("Unknown tool: {}", call.name)
|
|
};
|
|
|
|
let _ = writeln!(
|
|
tool_results,
|
|
"<tool_result name=\"{}\">\n{}\n</tool_result>",
|
|
call.name, result
|
|
);
|
|
}
|
|
|
|
// Add assistant message with tool calls + tool results to history
|
|
history.push(ChatMessage::assistant(assistant_history_content));
|
|
history.push(ChatMessage::user(format!("[Tool results]\n{tool_results}")));
|
|
}
|
|
|
|
anyhow::bail!("Agent exceeded maximum tool iterations ({MAX_TOOL_ITERATIONS})")
|
|
}
|
|
|
|
/// Build the tool instruction block for the system prompt so the LLM knows
|
|
/// how to invoke tools.
|
|
pub(crate) fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
|
|
let mut instructions = String::new();
|
|
instructions.push_str("\n## Tool Use Protocol\n\n");
|
|
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
|
instructions.push_str("```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n");
|
|
instructions.push_str(
|
|
"CRITICAL: Output actual <tool_call> tags—never describe steps or give examples.\n\n",
|
|
);
|
|
instructions.push_str("Example: User says \"what's the date?\". You MUST respond with:\n<tool_call>\n{\"name\":\"shell\",\"arguments\":{\"command\":\"date\"}}\n</tool_call>\n\n");
|
|
instructions.push_str("You may use multiple tool calls in a single response. ");
|
|
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
|
|
instructions
|
|
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
|
instructions.push_str("### Available Tools\n\n");
|
|
|
|
for tool in tools_registry {
|
|
let _ = writeln!(
|
|
instructions,
|
|
"**{}**: {}\nParameters: `{}`\n",
|
|
tool.name(),
|
|
tool.description(),
|
|
tool.parameters_schema()
|
|
);
|
|
}
|
|
|
|
instructions
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
pub async fn run(
|
|
config: Config,
|
|
message: Option<String>,
|
|
provider_override: Option<String>,
|
|
model_override: Option<String>,
|
|
temperature: f64,
|
|
peripheral_overrides: Vec<String>,
|
|
) -> Result<String> {
|
|
// ── Wire up agnostic subsystems ──────────────────────────────
|
|
let base_observer = observability::create_observer(&config.observability);
|
|
let observer: Arc<dyn Observer> = Arc::from(base_observer);
|
|
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
|
Arc::from(runtime::create_runtime(&config.runtime)?);
|
|
let security = Arc::new(SecurityPolicy::from_config(
|
|
&config.autonomy,
|
|
&config.workspace_dir,
|
|
));
|
|
|
|
// ── Memory (the brain) ────────────────────────────────────────
|
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
|
&config.memory,
|
|
&config.workspace_dir,
|
|
config.api_key.as_deref(),
|
|
)?);
|
|
tracing::info!(backend = mem.name(), "Memory initialized");
|
|
|
|
// ── Peripherals (merge peripheral tools into registry) ─
|
|
if !peripheral_overrides.is_empty() {
|
|
tracing::info!(
|
|
peripherals = ?peripheral_overrides,
|
|
"Peripheral overrides from CLI (config boards take precedence)"
|
|
);
|
|
}
|
|
|
|
// ── Tools (including memory tools and peripherals) ────────────
|
|
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
|
(
|
|
config.composio.api_key.as_deref(),
|
|
Some(config.composio.entity_id.as_str()),
|
|
)
|
|
} else {
|
|
(None, None)
|
|
};
|
|
let mut tools_registry = tools::all_tools_with_runtime(
|
|
Arc::new(config.clone()),
|
|
&security,
|
|
runtime,
|
|
mem.clone(),
|
|
composio_key,
|
|
composio_entity_id,
|
|
&config.browser,
|
|
&config.http_request,
|
|
&config.workspace_dir,
|
|
&config.agents,
|
|
config.api_key.as_deref(),
|
|
&config,
|
|
);
|
|
|
|
let peripheral_tools: Vec<Box<dyn Tool>> =
|
|
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
|
|
if !peripheral_tools.is_empty() {
|
|
tracing::info!(count = peripheral_tools.len(), "Peripheral tools added");
|
|
tools_registry.extend(peripheral_tools);
|
|
}
|
|
|
|
// ── Resolve provider ─────────────────────────────────────────
|
|
let provider_name = provider_override
|
|
.as_deref()
|
|
.or(config.default_provider.as_deref())
|
|
.unwrap_or("openrouter");
|
|
|
|
let model_name = model_override
|
|
.as_deref()
|
|
.or(config.default_model.as_deref())
|
|
.unwrap_or("anthropic/claude-sonnet-4");
|
|
|
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
|
provider_name,
|
|
config.api_key.as_deref(),
|
|
config.api_url.as_deref(),
|
|
&config.reliability,
|
|
&config.model_routes,
|
|
model_name,
|
|
)?;
|
|
|
|
observer.record_event(&ObserverEvent::AgentStart {
|
|
provider: provider_name.to_string(),
|
|
model: model_name.to_string(),
|
|
});
|
|
|
|
// ── Hardware RAG (datasheet retrieval when peripherals + datasheet_dir) ──
|
|
let hardware_rag: Option<crate::rag::HardwareRag> = config
|
|
.peripherals
|
|
.datasheet_dir
|
|
.as_ref()
|
|
.filter(|d| !d.trim().is_empty())
|
|
.map(|dir| crate::rag::HardwareRag::load(&config.workspace_dir, dir.trim()))
|
|
.and_then(Result::ok)
|
|
.filter(|r: &crate::rag::HardwareRag| !r.is_empty());
|
|
if let Some(ref rag) = hardware_rag {
|
|
tracing::info!(chunks = rag.len(), "Hardware RAG loaded");
|
|
}
|
|
|
|
let board_names: Vec<String> = config
|
|
.peripherals
|
|
.boards
|
|
.iter()
|
|
.map(|b| b.board.clone())
|
|
.collect();
|
|
|
|
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
|
let skills = crate::skills::load_skills(&config.workspace_dir);
|
|
let mut tool_descs: Vec<(&str, &str)> = vec![
|
|
(
|
|
"shell",
|
|
"Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
|
|
),
|
|
(
|
|
"file_read",
|
|
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
|
|
),
|
|
(
|
|
"file_write",
|
|
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
|
|
),
|
|
(
|
|
"memory_store",
|
|
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
|
),
|
|
(
|
|
"memory_recall",
|
|
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
|
),
|
|
(
|
|
"memory_forget",
|
|
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
|
|
),
|
|
];
|
|
tool_descs.push((
|
|
"cron_add",
|
|
"Create a cron job. Supports schedule kinds: cron, at, every; and job types: shell or agent.",
|
|
));
|
|
tool_descs.push((
|
|
"cron_list",
|
|
"List all cron jobs with schedule, status, and metadata.",
|
|
));
|
|
tool_descs.push(("cron_remove", "Remove a cron job by job_id."));
|
|
tool_descs.push((
|
|
"cron_update",
|
|
"Patch a cron job (schedule, enabled, command/prompt, model, delivery, session_target).",
|
|
));
|
|
tool_descs.push((
|
|
"cron_run",
|
|
"Force-run a cron job immediately and record a run history entry.",
|
|
));
|
|
tool_descs.push(("cron_runs", "Show recent run history for a cron job."));
|
|
tool_descs.push((
|
|
"screenshot",
|
|
"Capture a screenshot of the current screen. Returns file path and base64-encoded PNG. Use when: visual verification, UI inspection, debugging displays.",
|
|
));
|
|
tool_descs.push((
|
|
"image_info",
|
|
"Read image file metadata (format, dimensions, size) and optionally base64-encode it. Use when: inspecting images, preparing visual data for analysis.",
|
|
));
|
|
if config.browser.enabled {
|
|
tool_descs.push((
|
|
"browser_open",
|
|
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
|
|
));
|
|
}
|
|
if config.composio.enabled {
|
|
tool_descs.push((
|
|
"composio",
|
|
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.",
|
|
));
|
|
}
|
|
tool_descs.push((
|
|
"schedule",
|
|
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
|
));
|
|
if !config.agents.is_empty() {
|
|
tool_descs.push((
|
|
"delegate",
|
|
"Delegate a sub-task to a specialized agent. Use when: task needs different model/capability, or to parallelize work.",
|
|
));
|
|
}
|
|
if config.peripherals.enabled && !config.peripherals.boards.is_empty() {
|
|
tool_descs.push((
|
|
"gpio_read",
|
|
"Read GPIO pin value (0 or 1) on connected hardware (STM32, Arduino). Use when: checking sensor/button state, LED status.",
|
|
));
|
|
tool_descs.push((
|
|
"gpio_write",
|
|
"Set GPIO pin high (1) or low (0) on connected hardware. Use when: turning LED on/off, controlling actuators.",
|
|
));
|
|
tool_descs.push((
|
|
"arduino_upload",
|
|
"Upload agent-generated Arduino sketch. Use when: user asks for 'make a heart', 'blink pattern', or custom LED behavior on Arduino. You write the full .ino code; ZeroClaw compiles and uploads it. Pin 13 = built-in LED on Uno.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_memory_map",
|
|
"Return flash and RAM address ranges for connected hardware. Use when: user asks for 'upper and lower memory addresses', 'memory map', or 'readable addresses'.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_board_info",
|
|
"Return full board info (chip, architecture, memory map) for connected hardware. Use when: user asks for 'board info', 'what board do I have', 'connected hardware', 'chip info', or 'what hardware'.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_memory_read",
|
|
"Read actual memory/register values from Nucleo via USB. Use when: user asks to 'read register values', 'read memory', 'dump lower memory 0-126', 'give address and value'. Params: address (hex, default 0x20000000), length (bytes, default 128).",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_capabilities",
|
|
"Query connected hardware for reported GPIO pins and LED pin. Use when: user asks what pins are available.",
|
|
));
|
|
}
|
|
let bootstrap_max_chars = if config.agent.compact_context {
|
|
Some(6000)
|
|
} else {
|
|
None
|
|
};
|
|
let mut system_prompt = crate::channels::build_system_prompt(
|
|
&config.workspace_dir,
|
|
model_name,
|
|
&tool_descs,
|
|
&skills,
|
|
Some(&config.identity),
|
|
bootstrap_max_chars,
|
|
);
|
|
|
|
// Append structured tool-use instructions with schemas
|
|
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
|
|
|
// ── Execute ──────────────────────────────────────────────────
|
|
let start = Instant::now();
|
|
|
|
let mut final_output = String::new();
|
|
|
|
if let Some(msg) = message {
|
|
// Auto-save user message to memory
|
|
if config.memory.auto_save {
|
|
let user_key = autosave_memory_key("user_msg");
|
|
let _ = mem
|
|
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
|
.await;
|
|
}
|
|
|
|
// Inject memory + hardware RAG context into user message
|
|
let mem_context = build_context(mem.as_ref(), &msg).await;
|
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
|
let hw_context = hardware_rag
|
|
.as_ref()
|
|
.map(|r| build_hardware_context(r, &msg, &board_names, rag_limit))
|
|
.unwrap_or_default();
|
|
let context = format!("{mem_context}{hw_context}");
|
|
let enriched = if context.is_empty() {
|
|
msg.clone()
|
|
} else {
|
|
format!("{context}{msg}")
|
|
};
|
|
|
|
let mut history = vec![
|
|
ChatMessage::system(&system_prompt),
|
|
ChatMessage::user(&enriched),
|
|
];
|
|
|
|
let response = run_tool_call_loop(
|
|
provider.as_ref(),
|
|
&mut history,
|
|
&tools_registry,
|
|
observer.as_ref(),
|
|
provider_name,
|
|
model_name,
|
|
temperature,
|
|
false,
|
|
)
|
|
.await?;
|
|
final_output = response.clone();
|
|
println!("{response}");
|
|
observer.record_event(&ObserverEvent::TurnComplete);
|
|
|
|
// Auto-save assistant response to daily log
|
|
if config.memory.auto_save {
|
|
let summary = truncate_with_ellipsis(&response, 100);
|
|
let response_key = autosave_memory_key("assistant_resp");
|
|
let _ = mem
|
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
|
.await;
|
|
}
|
|
} else {
|
|
println!("🦀 ZeroClaw Interactive Mode");
|
|
println!("Type /quit to exit.\n");
|
|
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
|
|
let cli = crate::channels::CliChannel::new();
|
|
|
|
// Spawn listener
|
|
let listen_handle = tokio::spawn(async move {
|
|
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
|
});
|
|
|
|
// Persistent conversation history across turns
|
|
let mut history = vec![ChatMessage::system(&system_prompt)];
|
|
|
|
while let Some(msg) = rx.recv().await {
|
|
// Auto-save conversation turns
|
|
if config.memory.auto_save {
|
|
let user_key = autosave_memory_key("user_msg");
|
|
let _ = mem
|
|
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
|
|
.await;
|
|
}
|
|
|
|
// Inject memory + hardware RAG context into user message
|
|
let mem_context = build_context(mem.as_ref(), &msg.content).await;
|
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
|
let hw_context = hardware_rag
|
|
.as_ref()
|
|
.map(|r| build_hardware_context(r, &msg.content, &board_names, rag_limit))
|
|
.unwrap_or_default();
|
|
let context = format!("{mem_context}{hw_context}");
|
|
let enriched = if context.is_empty() {
|
|
msg.content.clone()
|
|
} else {
|
|
format!("{context}{}", msg.content)
|
|
};
|
|
|
|
history.push(ChatMessage::user(&enriched));
|
|
|
|
let response = match run_tool_call_loop(
|
|
provider.as_ref(),
|
|
&mut history,
|
|
&tools_registry,
|
|
observer.as_ref(),
|
|
provider_name,
|
|
model_name,
|
|
temperature,
|
|
false,
|
|
)
|
|
.await
|
|
{
|
|
Ok(resp) => resp,
|
|
Err(e) => {
|
|
eprintln!("\nError: {e}\n");
|
|
continue;
|
|
}
|
|
};
|
|
final_output = response.clone();
|
|
println!("\n{response}\n");
|
|
observer.record_event(&ObserverEvent::TurnComplete);
|
|
|
|
// Auto-compaction before hard trimming to preserve long-context signal.
|
|
if let Ok(compacted) =
|
|
auto_compact_history(&mut history, provider.as_ref(), model_name).await
|
|
{
|
|
if compacted {
|
|
println!("🧹 Auto-compaction complete");
|
|
}
|
|
}
|
|
|
|
// Hard cap as a safety net.
|
|
trim_history(&mut history);
|
|
|
|
if config.memory.auto_save {
|
|
let summary = truncate_with_ellipsis(&response, 100);
|
|
let response_key = autosave_memory_key("assistant_resp");
|
|
let _ = mem
|
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
|
.await;
|
|
}
|
|
}
|
|
|
|
listen_handle.abort();
|
|
}
|
|
|
|
let duration = start.elapsed();
|
|
observer.record_event(&ObserverEvent::AgentEnd {
|
|
duration,
|
|
tokens_used: None,
|
|
cost_usd: None,
|
|
});
|
|
|
|
Ok(final_output)
|
|
}
|
|
|
|
/// Process a single message through the full agent (with tools, peripherals, memory).
|
|
/// Used by channels (Telegram, Discord, etc.) to enable hardware and tool use.
|
|
pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|
let observer: Arc<dyn Observer> =
|
|
Arc::from(observability::create_observer(&config.observability));
|
|
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
|
Arc::from(runtime::create_runtime(&config.runtime)?);
|
|
let security = Arc::new(SecurityPolicy::from_config(
|
|
&config.autonomy,
|
|
&config.workspace_dir,
|
|
));
|
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
|
&config.memory,
|
|
&config.workspace_dir,
|
|
config.api_key.as_deref(),
|
|
)?);
|
|
|
|
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
|
(
|
|
config.composio.api_key.as_deref(),
|
|
Some(config.composio.entity_id.as_str()),
|
|
)
|
|
} else {
|
|
(None, None)
|
|
};
|
|
let mut tools_registry = tools::all_tools_with_runtime(
|
|
Arc::new(config.clone()),
|
|
&security,
|
|
runtime,
|
|
mem.clone(),
|
|
composio_key,
|
|
composio_entity_id,
|
|
&config.browser,
|
|
&config.http_request,
|
|
&config.workspace_dir,
|
|
&config.agents,
|
|
config.api_key.as_deref(),
|
|
&config,
|
|
);
|
|
let peripheral_tools: Vec<Box<dyn Tool>> =
|
|
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
|
|
tools_registry.extend(peripheral_tools);
|
|
|
|
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
|
let model_name = config
|
|
.default_model
|
|
.clone()
|
|
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
|
provider_name,
|
|
config.api_key.as_deref(),
|
|
config.api_url.as_deref(),
|
|
&config.reliability,
|
|
&config.model_routes,
|
|
&model_name,
|
|
)?;
|
|
|
|
let hardware_rag: Option<crate::rag::HardwareRag> = config
|
|
.peripherals
|
|
.datasheet_dir
|
|
.as_ref()
|
|
.filter(|d| !d.trim().is_empty())
|
|
.map(|dir| crate::rag::HardwareRag::load(&config.workspace_dir, dir.trim()))
|
|
.and_then(Result::ok)
|
|
.filter(|r: &crate::rag::HardwareRag| !r.is_empty());
|
|
let board_names: Vec<String> = config
|
|
.peripherals
|
|
.boards
|
|
.iter()
|
|
.map(|b| b.board.clone())
|
|
.collect();
|
|
|
|
let skills = crate::skills::load_skills(&config.workspace_dir);
|
|
let mut tool_descs: Vec<(&str, &str)> = vec![
|
|
("shell", "Execute terminal commands."),
|
|
("file_read", "Read file contents."),
|
|
("file_write", "Write file contents."),
|
|
("memory_store", "Save to memory."),
|
|
("memory_recall", "Search memory."),
|
|
("memory_forget", "Delete a memory entry."),
|
|
("screenshot", "Capture a screenshot."),
|
|
("image_info", "Read image metadata."),
|
|
];
|
|
if config.browser.enabled {
|
|
tool_descs.push(("browser_open", "Open approved URLs in browser."));
|
|
}
|
|
if config.composio.enabled {
|
|
tool_descs.push(("composio", "Execute actions on 1000+ apps via Composio."));
|
|
}
|
|
if config.peripherals.enabled && !config.peripherals.boards.is_empty() {
|
|
tool_descs.push(("gpio_read", "Read GPIO pin value on connected hardware."));
|
|
tool_descs.push((
|
|
"gpio_write",
|
|
"Set GPIO pin high or low on connected hardware.",
|
|
));
|
|
tool_descs.push((
|
|
"arduino_upload",
|
|
"Upload Arduino sketch. Use for 'make a heart', custom patterns. You write full .ino code; ZeroClaw uploads it.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_memory_map",
|
|
"Return flash and RAM address ranges. Use when user asks for memory addresses or memory map.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_board_info",
|
|
"Return full board info (chip, architecture, memory map). Use when user asks for board info, what board, connected hardware, or chip info.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_memory_read",
|
|
"Read actual memory/register values from Nucleo. Use when user asks to read registers, read memory, dump lower memory 0-126, or give address and value.",
|
|
));
|
|
tool_descs.push((
|
|
"hardware_capabilities",
|
|
"Query connected hardware for reported GPIO pins and LED pin. Use when user asks what pins are available.",
|
|
));
|
|
}
|
|
let bootstrap_max_chars = if config.agent.compact_context {
|
|
Some(6000)
|
|
} else {
|
|
None
|
|
};
|
|
let mut system_prompt = crate::channels::build_system_prompt(
|
|
&config.workspace_dir,
|
|
&model_name,
|
|
&tool_descs,
|
|
&skills,
|
|
Some(&config.identity),
|
|
bootstrap_max_chars,
|
|
);
|
|
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
|
|
|
let mem_context = build_context(mem.as_ref(), message).await;
|
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
|
let hw_context = hardware_rag
|
|
.as_ref()
|
|
.map(|r| build_hardware_context(r, message, &board_names, rag_limit))
|
|
.unwrap_or_default();
|
|
let context = format!("{mem_context}{hw_context}");
|
|
let enriched = if context.is_empty() {
|
|
message.to_string()
|
|
} else {
|
|
format!("{context}{message}")
|
|
};
|
|
|
|
let mut history = vec![
|
|
ChatMessage::system(&system_prompt),
|
|
ChatMessage::user(&enriched),
|
|
];
|
|
|
|
agent_turn(
|
|
provider.as_ref(),
|
|
&mut history,
|
|
&tools_registry,
|
|
observer.as_ref(),
|
|
provider_name,
|
|
&model_name,
|
|
config.default_temperature,
|
|
true,
|
|
)
|
|
.await
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Memory, MemoryCategory, SqliteMemory};
    use tempfile::TempDir;

    /// Default-autonomy security policy rooted at `/tmp`, used to build the stock tool set.
    fn tmp_security_policy() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy::from_config(
            &crate::config::AutonomyConfig::default(),
            std::path::Path::new("/tmp"),
        ))
    }

    #[test]
    fn parse_tool_calls_extracts_single_call() {
        let raw = r#"Let me check that.
<tool_call>
{"name": "shell", "arguments": {"command": "ls -la"}}
</tool_call>"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert_eq!(prose, "Let me check that.");
        assert_eq!(parsed.len(), 1);
        let call = &parsed[0];
        assert_eq!(call.name, "shell");
        assert_eq!(
            call.arguments.get("command").and_then(|v| v.as_str()),
            Some("ls -la")
        );
    }

    #[test]
    fn parse_tool_calls_extracts_multiple_calls() {
        let raw = r#"<tool_call>
{"name": "file_read", "arguments": {"path": "a.txt"}}
</tool_call>
<tool_call>
{"name": "file_read", "arguments": {"path": "b.txt"}}
</tool_call>"#;

        let (_, parsed) = parse_tool_calls(raw);

        assert_eq!(parsed.len(), 2);
        assert!(parsed.iter().all(|call| call.name == "file_read"));
    }

    #[test]
    fn parse_tool_calls_returns_text_only_when_no_calls() {
        let raw = "Just a normal response with no tools.";

        let (prose, parsed) = parse_tool_calls(raw);

        assert_eq!(prose, raw);
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_tool_calls_handles_malformed_json() {
        let raw = r#"<tool_call>
not valid json
</tool_call>
Some text after."#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(parsed.is_empty());
        assert!(prose.contains("Some text after."));
    }

    #[test]
    fn parse_tool_calls_text_before_and_after() {
        let raw = r#"Before text.
<tool_call>
{"name": "shell", "arguments": {"command": "echo hi"}}
</tool_call>
After text."#;

        let (prose, parsed) = parse_tool_calls(raw);

        for fragment in ["Before text.", "After text."] {
            assert!(prose.contains(fragment), "missing {fragment:?} in {prose:?}");
        }
        assert_eq!(parsed.len(), 1);
    }

    #[test]
    fn parse_tool_calls_handles_openai_format() {
        // OpenAI-style payload: a JSON object carrying `content` plus a `tool_calls` array.
        let raw = r#"{"content": "Let me check that for you.", "tool_calls": [{"type": "function", "function": {"name": "shell", "arguments": "{\"command\": \"ls -la\"}"}}]}"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert_eq!(prose, "Let me check that for you.");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "shell");
        assert_eq!(
            parsed[0].arguments.get("command").and_then(|v| v.as_str()),
            Some("ls -la")
        );
    }

    #[test]
    fn parse_tool_calls_handles_openai_format_multiple_calls() {
        let raw = r#"{"tool_calls": [{"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"a.txt\"}"}}, {"type": "function", "function": {"name": "file_read", "arguments": "{\"path\": \"b.txt\"}"}}]}"#;

        let (_, parsed) = parse_tool_calls(raw);

        assert_eq!(parsed.len(), 2);
        assert!(parsed.iter().all(|call| call.name == "file_read"));
    }

    #[test]
    fn parse_tool_calls_openai_format_without_content() {
        // Some providers emit `tool_calls` with no accompanying `content` field.
        let raw = r#"{"tool_calls": [{"type": "function", "function": {"name": "memory_recall", "arguments": "{}"}}]}"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.is_empty(), "no content field means no prose");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "memory_recall");
    }

    #[test]
    fn parse_tool_calls_handles_markdown_json_inside_tool_call_tag() {
        let raw = r#"<tool_call>
```json
{"name": "file_write", "arguments": {"path": "test.py", "content": "print('ok')"}}
```
</tool_call>"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.is_empty());
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "file_write");
        assert_eq!(
            parsed[0].arguments.get("path").and_then(|v| v.as_str()),
            Some("test.py")
        );
    }

    #[test]
    fn parse_tool_calls_handles_noisy_tool_call_tag_body() {
        let raw = r#"<tool_call>
I will now call the tool with this payload:
{"name": "shell", "arguments": {"command": "pwd"}}
</tool_call>"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.is_empty());
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "shell");
        assert_eq!(
            parsed[0].arguments.get("command").and_then(|v| v.as_str()),
            Some("pwd")
        );
    }

    #[test]
    fn parse_tool_calls_rejects_raw_tool_json_without_tags() {
        // SECURITY: bare JSON that merely looks like a tool call must not be executed.
        // Requiring explicit wrappers keeps injected content (e.g. from a fetched page)
        // from smuggling in tool invocations.
        let raw = r#"Sure, creating the file now.
{"name": "file_write", "arguments": {"path": "hello.py", "content": "print('hello')"}}"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.contains("Sure, creating the file now."));
        assert!(
            parsed.is_empty(),
            "Raw JSON without wrappers should not be parsed"
        );
    }

    #[test]
    fn build_tool_instructions_includes_all_tools() {
        let instructions = build_tool_instructions(&tools::default_tools(tmp_security_policy()));

        for needle in [
            "## Tool Use Protocol",
            "<tool_call>",
            "shell",
            "file_read",
            "file_write",
        ] {
            assert!(instructions.contains(needle), "instructions missing {needle:?}");
        }
    }

    #[test]
    fn tools_to_openai_format_produces_valid_schema() {
        let formatted = tools_to_openai_format(&tools::default_tools(tmp_security_policy()));

        assert!(!formatted.is_empty());
        for entry in &formatted {
            assert_eq!(entry["type"], "function");
            let function = &entry["function"];
            assert!(function["name"].is_string());
            assert!(function["description"].is_string());
            assert!(!function["name"].as_str().unwrap().is_empty());
        }

        // A couple of well-known built-ins must be advertised.
        let names: Vec<&str> = formatted
            .iter()
            .filter_map(|entry| entry["function"]["name"].as_str())
            .collect();
        for expected in ["shell", "file_read"] {
            assert!(names.contains(&expected), "missing tool {expected:?}");
        }
    }

    #[test]
    fn trim_history_preserves_system_prompt() {
        let overflow = 20;
        let mut history = vec![ChatMessage::system("system prompt")];
        history.extend(
            (0..MAX_HISTORY_MESSAGES + overflow).map(|i| ChatMessage::user(format!("msg {i}"))),
        );
        assert!(history.len() > MAX_HISTORY_MESSAGES + 1);

        trim_history(&mut history);

        // The system prompt stays pinned at the front.
        assert_eq!(history[0].role, "system");
        assert_eq!(history[0].content, "system prompt");
        // Exactly the cap of conversational messages survives, plus the system prompt.
        assert_eq!(history.len(), MAX_HISTORY_MESSAGES + 1);
        // The newest message is the one kept at the tail.
        let newest = history.last().expect("history cannot be empty after trimming");
        assert_eq!(
            newest.content,
            format!("msg {}", MAX_HISTORY_MESSAGES + overflow - 1)
        );
    }

    #[test]
    fn trim_history_noop_when_within_limit() {
        let mut history = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("hello"),
            ChatMessage::assistant("hi"),
        ];
        let before = history.len();

        trim_history(&mut history);

        assert_eq!(history.len(), before);
    }

    #[test]
    fn build_compaction_transcript_formats_roles() {
        let turns = vec![
            ChatMessage::user("I like dark mode"),
            ChatMessage::assistant("Got it"),
        ];

        let transcript = build_compaction_transcript(&turns);

        for line in ["USER: I like dark mode", "ASSISTANT: Got it"] {
            assert!(transcript.contains(line), "transcript missing {line:?}");
        }
    }

    #[test]
    fn apply_compaction_summary_replaces_old_segment() {
        let mut history = vec![
            ChatMessage::system("sys"),
            ChatMessage::user("old 1"),
            ChatMessage::assistant("old 2"),
            ChatMessage::user("recent 1"),
            ChatMessage::assistant("recent 2"),
        ];

        apply_compaction_summary(&mut history, 1, 3, "- user prefers concise replies");

        assert_eq!(history.len(), 4);
        assert!(history[1].content.contains("Compaction summary"));
        for (idx, expected) in [(2, "recent 1"), (3, "recent 2")] {
            assert!(history[idx].content.contains(expected));
        }
    }

    #[test]
    fn autosave_memory_key_has_prefix_and_uniqueness() {
        let first = autosave_memory_key("user_msg");
        let second = autosave_memory_key("user_msg");

        for key in [&first, &second] {
            assert!(key.starts_with("user_msg_"), "unexpected key {key:?}");
        }
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn autosave_memory_keys_preserve_multiple_turns() {
        let workspace = TempDir::new().unwrap();
        let mem = SqliteMemory::new(workspace.path()).unwrap();

        // Two turns saved under freshly generated keys must not overwrite each other.
        for content in ["I'm Paul", "I'm 45"] {
            let key = autosave_memory_key("user_msg");
            mem.store(&key, content, MemoryCategory::Conversation, None)
                .await
                .unwrap();
        }

        assert_eq!(mem.count().await.unwrap(), 2);

        let recalled = mem.recall("45", 5, None).await.unwrap();
        assert!(recalled.iter().any(|entry| entry.content.contains("45")));
    }

    // ── Recovery: tool-call parsing edge cases ──────────────────────────────

    #[test]
    fn parse_tool_calls_handles_empty_tool_result() {
        // An empty <tool_result> block must not derail parsing of the surrounding text.
        let raw = r#"I'll run that command.
<tool_result name="shell">

</tool_result>
Done."#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.contains("Done."));
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_arguments_value_handles_null() {
        // A JSON null is handed back unchanged rather than coerced to an object.
        let null = serde_json::Value::Null;

        assert!(parse_arguments_value(Some(&null)).is_null());
    }

    #[test]
    fn parse_tool_calls_handles_empty_tool_calls_array() {
        // With an empty `tool_calls` array nothing is parsed and the payload text survives.
        let raw = r#"{"content": "Hello", "tool_calls": []}"#;

        let (prose, parsed) = parse_tool_calls(raw);

        assert!(prose.contains("Hello"));
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_tool_calls_handles_whitespace_only_name() {
        // A name made only of whitespace is not a usable tool name.
        let value = serde_json::json!({"function": {"name": " ", "arguments": {}}});

        assert!(parse_tool_call_value(&value).is_none());
    }

    #[test]
    fn parse_tool_calls_handles_empty_string_arguments() {
        // An empty argument string still yields a call as long as the name is valid.
        let value = serde_json::json!({"name": "test", "arguments": ""});

        let parsed = parse_tool_call_value(&value).expect("call with empty arguments should parse");
        assert_eq!(parsed.name, "test");
    }

    // ── Recovery: history management ────────────────────────────────────────

    #[test]
    fn trim_history_with_no_system_prompt() {
        // Without a system prompt the whole cap is available to conversational turns.
        let mut history: Vec<ChatMessage> = (0..MAX_HISTORY_MESSAGES + 20)
            .map(|i| ChatMessage::user(format!("msg {i}")))
            .collect();

        trim_history(&mut history);

        assert_eq!(history.len(), MAX_HISTORY_MESSAGES);
    }

    #[test]
    fn trim_history_preserves_role_ordering() {
        // Trimming must keep the system prompt first and the latest assistant reply last.
        let mut history = vec![ChatMessage::system("system")];
        for i in 0..MAX_HISTORY_MESSAGES + 10 {
            history.extend([
                ChatMessage::user(format!("user {i}")),
                ChatMessage::assistant(format!("assistant {i}")),
            ]);
        }

        trim_history(&mut history);

        assert_eq!(history[0].role, "system");
        let tail = history.last().unwrap();
        assert_eq!(tail.role, "assistant");
    }

    #[test]
    fn trim_history_with_only_system_prompt() {
        // A lone system prompt is never removed.
        let mut history = vec![ChatMessage::system("system prompt")];

        trim_history(&mut history);

        assert_eq!(history.len(), 1);
    }

    // ── Recovery: argument parsing ──────────────────────────────────────────

    #[test]
    fn parse_arguments_value_handles_invalid_json_string() {
        // A string that is not valid JSON degrades to an empty object.
        let garbage = serde_json::Value::String("not valid json".to_string());

        let parsed = parse_arguments_value(Some(&garbage));

        assert_eq!(parsed.as_object().map(|obj| obj.len()), Some(0));
    }

    #[test]
    fn parse_arguments_value_handles_none() {
        // Missing arguments degrade to an empty object.
        let parsed = parse_arguments_value(None);

        assert_eq!(parsed.as_object().map(|obj| obj.len()), Some(0));
    }

    // ── Recovery: JSON extraction ───────────────────────────────────────────

    #[test]
    fn extract_json_values_handles_empty_string() {
        assert!(extract_json_values("").is_empty());
    }

    #[test]
    fn extract_json_values_handles_whitespace_only() {
        assert!(extract_json_values(" \n\t ").is_empty());
    }

    #[test]
    fn extract_json_values_handles_multiple_objects() {
        // Back-to-back objects are each extracted.
        assert_eq!(extract_json_values(r#"{"a": 1}{"b": 2}{"c": 3}"#).len(), 3);
    }

    #[test]
    fn extract_json_values_handles_arrays() {
        // Arrays count as JSON values alongside objects.
        assert_eq!(extract_json_values(r#"[1, 2, 3]{"key": "value"}"#).len(), 2);
    }

    // ── Recovery: constant sanity bounds ────────────────────────────────────

    const _: () = {
        assert!(MAX_TOOL_ITERATIONS > 0);
        assert!(MAX_TOOL_ITERATIONS <= 100);
        assert!(MAX_HISTORY_MESSAGES > 0);
        assert!(MAX_HISTORY_MESSAGES <= 1000);
    };

    #[test]
    fn constants_bounds_are_compile_time_checked() {
        // Nothing to do at runtime: the const block above fails compilation on violation.
    }

    // ── Recovery: tool-call value parsing ───────────────────────────────────

    #[test]
    fn parse_tool_call_value_handles_missing_name_field() {
        let value = serde_json::json!({"function": {"arguments": {}}});

        assert!(parse_tool_call_value(&value).is_none());
    }

    #[test]
    fn parse_tool_call_value_handles_top_level_name() {
        // Non-OpenAI shape: `name` sits at the top level instead of under `function`.
        let value = serde_json::json!({"name": "test_tool", "arguments": {}});

        let parsed = parse_tool_call_value(&value).expect("top-level name should parse");
        assert_eq!(parsed.name, "test_tool");
    }

    #[test]
    fn parse_tool_calls_from_json_value_handles_empty_array() {
        let value = serde_json::json!({"tool_calls": []});

        assert!(parse_tool_calls_from_json_value(&value).is_empty());
    }

    #[test]
    fn parse_tool_calls_from_json_value_handles_missing_tool_calls() {
        // Without a `tool_calls` key the value itself is treated as a single call.
        let value = serde_json::json!({"name": "test", "arguments": {}});

        assert_eq!(parse_tool_calls_from_json_value(&value).len(), 1);
    }

    #[test]
    fn parse_tool_calls_from_json_value_handles_top_level_array() {
        let value = serde_json::json!([
            {"name": "tool_a", "arguments": {}},
            {"name": "tool_b", "arguments": {}}
        ]);

        assert_eq!(parse_tool_calls_from_json_value(&value).len(), 2);
    }
}
|