zeroclaw/src/agent/agent.rs

710 lines
22 KiB
Rust

use crate::agent::dispatcher::{
NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher,
};
use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader};
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::observability::{self, Observer, ObserverEvent};
use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Provider};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool, ToolSpec};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use std::io::Write as IoWrite;
use std::sync::Arc;
use std::time::Instant;
/// Conversational agent that pairs an LLM [`Provider`] with a tool set,
/// memory, an observer, and the running conversation history.
///
/// Build one with [`Agent::builder`] or [`Agent::from_config`].
pub struct Agent {
/// Backend that answers every chat request of a turn.
provider: Box<dyn Provider>,
/// Tools the model may call; looked up by name when a tool call is parsed.
tools: Vec<Box<dyn Tool>>,
/// Specs derived from `tools` once, when the agent is built.
tool_specs: Vec<ToolSpec>,
/// Memory backend used for context loading and, when `auto_save` is set, persistence.
memory: Arc<dyn Memory>,
/// Receives tool-call and lifecycle events.
observer: Arc<dyn Observer>,
/// Renders the system prompt whenever a turn starts with an empty history.
prompt_builder: SystemPromptBuilder,
/// Encodes history/tool specs for the provider and parses tool calls from responses.
tool_dispatcher: Box<dyn ToolDispatcher>,
/// Fetches memory context that is prepended to each user message.
memory_loader: Box<dyn MemoryLoader>,
/// Agent options (history cap, tool-iteration cap, parallel tool execution, ...).
config: crate::config::AgentConfig,
/// Model identifier passed to every provider call.
model_name: String,
/// Sampling temperature passed to every provider call.
temperature: f64,
/// Workspace directory exposed to the system prompt.
workspace_dir: std::path::PathBuf,
/// Identity settings exposed to the system prompt.
identity_config: crate::config::IdentityConfig,
/// Skills exposed to the system prompt.
skills: Vec<crate::skills::Skill>,
/// When true, user messages and truncated final replies are stored in `memory`.
auto_save: bool,
/// Conversation so far; once a turn has run, it starts with the system prompt.
history: Vec<ConversationMessage>,
}
/// Step-by-step constructor for [`Agent`].
///
/// Every field starts unset. [`AgentBuilder::build`] rejects a missing
/// provider, tool list, memory, observer or tool dispatcher, and fills
/// defaults for everything else.
pub struct AgentBuilder {
provider: Option<Box<dyn Provider>>,
tools: Option<Vec<Box<dyn Tool>>>,
memory: Option<Arc<dyn Memory>>,
observer: Option<Arc<dyn Observer>>,
prompt_builder: Option<SystemPromptBuilder>,
tool_dispatcher: Option<Box<dyn ToolDispatcher>>,
memory_loader: Option<Box<dyn MemoryLoader>>,
config: Option<crate::config::AgentConfig>,
model_name: Option<String>,
temperature: Option<f64>,
workspace_dir: Option<std::path::PathBuf>,
identity_config: Option<crate::config::IdentityConfig>,
skills: Option<Vec<crate::skills::Skill>>,
auto_save: Option<bool>,
}
impl AgentBuilder {
    /// Creates a builder with every field unset.
    pub fn new() -> Self {
        Self {
            provider: None,
            tools: None,
            memory: None,
            observer: None,
            prompt_builder: None,
            tool_dispatcher: None,
            memory_loader: None,
            config: None,
            model_name: None,
            temperature: None,
            workspace_dir: None,
            identity_config: None,
            skills: None,
            auto_save: None,
        }
    }

    /// Sets the LLM provider (required).
    pub fn provider(mut self, provider: Box<dyn Provider>) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the tools the model may call (required).
    pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the memory backend (required).
    pub fn memory(mut self, memory: Arc<dyn Memory>) -> Self {
        self.memory = Some(memory);
        self
    }

    /// Sets the observer that receives agent events (required).
    pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
        self.observer = Some(observer);
        self
    }

    /// Sets the system-prompt builder (defaults to `SystemPromptBuilder::with_defaults()`).
    pub fn prompt_builder(mut self, prompt_builder: SystemPromptBuilder) -> Self {
        self.prompt_builder = Some(prompt_builder);
        self
    }

    /// Sets the tool-call dispatcher (required).
    pub fn tool_dispatcher(mut self, tool_dispatcher: Box<dyn ToolDispatcher>) -> Self {
        self.tool_dispatcher = Some(tool_dispatcher);
        self
    }

    /// Sets the memory-context loader (defaults to `DefaultMemoryLoader::default()`).
    pub fn memory_loader(mut self, memory_loader: Box<dyn MemoryLoader>) -> Self {
        self.memory_loader = Some(memory_loader);
        self
    }

    /// Sets the agent options (defaults to `AgentConfig::default()`).
    pub fn config(mut self, config: crate::config::AgentConfig) -> Self {
        self.config = Some(config);
        self
    }

    /// Sets the model identifier (defaults to `"anthropic/claude-sonnet-4-20250514"`).
    pub fn model_name(mut self, model_name: String) -> Self {
        self.model_name = Some(model_name);
        self
    }

    /// Sets the sampling temperature (defaults to `0.7`).
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the workspace directory (defaults to `"."`).
    pub fn workspace_dir(mut self, workspace_dir: std::path::PathBuf) -> Self {
        self.workspace_dir = Some(workspace_dir);
        self
    }

    /// Sets the identity configuration (defaults to `IdentityConfig::default()`).
    pub fn identity_config(mut self, identity_config: crate::config::IdentityConfig) -> Self {
        self.identity_config = Some(identity_config);
        self
    }

    /// Sets the skills exposed in the system prompt (defaults to none).
    pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
        self.skills = Some(skills);
        self
    }

    /// Enables or disables auto-saving of turns to memory (defaults to `false`).
    pub fn auto_save(mut self, auto_save: bool) -> Self {
        self.auto_save = Some(auto_save);
        self
    }

    /// Assembles the [`Agent`] with an empty history.
    ///
    /// Tool specs are computed once here from the supplied tools.
    ///
    /// # Errors
    /// Returns an error naming the first missing required part, checked in
    /// this order: `tools`, `provider`, `memory`, `observer`, `tool_dispatcher`.
    pub fn build(self) -> Result<Agent> {
        let tools = self
            .tools
            .ok_or_else(|| anyhow::anyhow!("tools are required"))?;
        let tool_specs = tools.iter().map(|tool| tool.spec()).collect();
        Ok(Agent {
            provider: self
                .provider
                .ok_or_else(|| anyhow::anyhow!("provider is required"))?,
            tools,
            tool_specs,
            memory: self
                .memory
                .ok_or_else(|| anyhow::anyhow!("memory is required"))?,
            observer: self
                .observer
                .ok_or_else(|| anyhow::anyhow!("observer is required"))?,
            prompt_builder: self
                .prompt_builder
                .unwrap_or_else(SystemPromptBuilder::with_defaults),
            tool_dispatcher: self
                .tool_dispatcher
                .ok_or_else(|| anyhow::anyhow!("tool_dispatcher is required"))?,
            memory_loader: self
                .memory_loader
                .unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
            config: self.config.unwrap_or_default(),
            model_name: self
                .model_name
                .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
            temperature: self.temperature.unwrap_or(0.7),
            workspace_dir: self
                .workspace_dir
                .unwrap_or_else(|| std::path::PathBuf::from(".")),
            identity_config: self.identity_config.unwrap_or_default(),
            skills: self.skills.unwrap_or_default(),
            auto_save: self.auto_save.unwrap_or(false),
            history: Vec::new(),
        })
    }
}

/// Equivalent to [`AgentBuilder::new`]: every field unset.
impl Default for AgentBuilder {
    fn default() -> Self {
        Self::new()
    }
}
impl Agent {
    /// Returns an empty [`AgentBuilder`].
    pub fn builder() -> AgentBuilder {
        AgentBuilder::new()
    }

    /// Read-only view of the conversation history, system prompt included.
    pub fn history(&self) -> &[ConversationMessage] {
        &self.history
    }

    /// Forgets the whole conversation.
    ///
    /// Because history is then empty, the next [`Agent::turn`] rebuilds and
    /// re-inserts the system prompt.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }

    /// Builds a fully wired agent from the application [`Config`].
    ///
    /// Creates the observer, runtime, security policy, memory backend and tool
    /// set, then resolves the provider and model:
    /// - provider defaults to `"openrouter"`;
    /// - model defaults to `"anthropic/claude-sonnet-4-20250514"`.
    ///
    /// The tool dispatcher is chosen from `config.agent.tool_dispatcher`:
    /// `"native"` and `"xml"` force that dispatcher. Any other value uses native
    /// tool calling when the provider supports it, and XML otherwise.
    ///
    /// # Errors
    /// Propagates failures from runtime, memory or provider creation, and from
    /// [`AgentBuilder::build`].
    pub fn from_config(config: &Config) -> Result<Self> {
        let observer: Arc<dyn Observer> =
            Arc::from(observability::create_observer(&config.observability));
        let runtime: Arc<dyn runtime::RuntimeAdapter> =
            Arc::from(runtime::create_runtime(&config.runtime)?);
        let security = Arc::new(SecurityPolicy::from_config(
            &config.autonomy,
            &config.workspace_dir,
        ));
        let memory: Arc<dyn Memory> = Arc::from(memory::create_memory(
            &config.memory,
            &config.workspace_dir,
            config.api_key.as_deref(),
        )?);

        // Composio credentials are only forwarded when the integration is enabled.
        let (composio_key, composio_entity_id) = if config.composio.enabled {
            (
                config.composio.api_key.as_deref(),
                Some(config.composio.entity_id.as_str()),
            )
        } else {
            (None, None)
        };

        let tools = tools::all_tools_with_runtime(
            Arc::new(config.clone()),
            &security,
            runtime,
            memory.clone(),
            composio_key,
            composio_entity_id,
            &config.browser,
            &config.http_request,
            &config.workspace_dir,
            &config.agents,
            config.api_key.as_deref(),
            config,
        );

        let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
        let model_name = config
            .default_model
            .as_deref()
            .unwrap_or("anthropic/claude-sonnet-4-20250514")
            .to_string();
        let provider: Box<dyn Provider> = providers::create_routed_provider(
            provider_name,
            config.api_key.as_deref(),
            config.api_url.as_deref(),
            &config.reliability,
            &config.model_routes,
            &model_name,
        )?;

        let tool_dispatcher: Box<dyn ToolDispatcher> = match config.agent.tool_dispatcher.as_str() {
            "native" => Box::new(NativeToolDispatcher),
            "xml" => Box::new(XmlToolDispatcher),
            _ if provider.supports_native_tools() => Box::new(NativeToolDispatcher),
            _ => Box::new(XmlToolDispatcher),
        };

        Agent::builder()
            .provider(provider)
            .tools(tools)
            .memory(memory)
            .observer(observer)
            .tool_dispatcher(tool_dispatcher)
            .memory_loader(Box::new(DefaultMemoryLoader::default()))
            .prompt_builder(SystemPromptBuilder::with_defaults())
            .config(config.agent.clone())
            .model_name(model_name)
            .temperature(config.default_temperature)
            .workspace_dir(config.workspace_dir.clone())
            .identity_config(config.identity.clone())
            .skills(crate::skills::load_skills(&config.workspace_dir))
            .auto_save(config.memory.auto_save)
            .build()
    }

    /// Caps the non-system part of `history` at `config.max_history_messages`.
    ///
    /// - The oldest entries are dropped first.
    /// - System messages are always kept and placed at the front.
    /// - After the cut, any `ToolResults` entries at the start of the kept
    ///   window are dropped as well. The `AssistantToolCalls` entry that
    ///   produced them was trimmed away, so they would otherwise reach the
    ///   provider as results without a matching call.
    fn trim_history(&mut self) {
        let max = self.config.max_history_messages;
        if self.history.len() <= max {
            return;
        }
        let (system_messages, mut other_messages): (Vec<_>, Vec<_>) =
            self.history.drain(..).partition(|msg| {
                matches!(msg, ConversationMessage::Chat(chat) if chat.role == "system")
            });
        if other_messages.len() > max {
            let drop_count = other_messages.len() - max;
            other_messages.drain(..drop_count);
            let orphaned = other_messages
                .iter()
                .take_while(|msg| matches!(msg, ConversationMessage::ToolResults(_)))
                .count();
            other_messages.drain(..orphaned);
        }
        self.history = system_messages;
        self.history.extend(other_messages);
    }

    /// Renders the system prompt from workspace, model, tools, skills, identity
    /// and the dispatcher's tool-calling instructions.
    fn build_system_prompt(&self) -> Result<String> {
        let instructions = self.tool_dispatcher.prompt_instructions(&self.tools);
        let ctx = PromptContext {
            workspace_dir: &self.workspace_dir,
            model_name: &self.model_name,
            tools: &self.tools,
            skills: &self.skills,
            identity_config: Some(&self.identity_config),
            dispatcher_instructions: &instructions,
        };
        self.prompt_builder.build(&ctx)
    }

    /// Runs one parsed tool call and packages the outcome for the dispatcher.
    ///
    /// Never fails. Instead, the returned `success` flag reflects what happened:
    /// - a successful tool result gives its output and `success: true`;
    /// - a tool-reported failure gives `Error: ...` and `success: false`;
    /// - an execution error gives `Error executing ...` and `success: false`;
    /// - an unknown tool name gives `Unknown tool: ...` and `success: false`.
    ///
    /// A `ToolCall` observer event is recorded for every call that reaches a tool.
    async fn execute_tool_call(&self, call: &ParsedToolCall) -> ToolExecutionResult {
        let start = Instant::now();
        let (output, success) = match self.tools.iter().find(|t| t.name() == call.name) {
            Some(tool) => {
                let outcome = tool.execute(call.arguments.clone()).await;
                let duration = start.elapsed();
                let (output, success) = match outcome {
                    Ok(r) if r.success => (r.output, true),
                    Ok(r) => (format!("Error: {}", r.error.unwrap_or(r.output)), false),
                    Err(e) => (format!("Error executing {}: {e}", call.name), false),
                };
                self.observer.record_event(&ObserverEvent::ToolCall {
                    tool: call.name.clone(),
                    duration,
                    success,
                });
                (output, success)
            }
            None => (format!("Unknown tool: {}", call.name), false),
        };
        ToolExecutionResult {
            name: call.name.clone(),
            output,
            success,
            tool_call_id: call.tool_call_id.clone(),
        }
    }

    /// Executes `calls` and returns their results in the same order as `calls`.
    ///
    /// With `config.parallel_tools` enabled and more than one call, all calls
    /// are driven concurrently on the current task. Nothing is spawned, so the
    /// `&self` borrows stay valid. Otherwise the calls run one after another.
    async fn execute_tools(&self, calls: &[ParsedToolCall]) -> Vec<ToolExecutionResult> {
        if !self.config.parallel_tools || calls.len() < 2 {
            let mut results = Vec::with_capacity(calls.len());
            for call in calls {
                results.push(self.execute_tool_call(call).await);
            }
            return results;
        }

        let mut pending: Vec<_> = calls
            .iter()
            .map(|call| Box::pin(self.execute_tool_call(call)))
            .collect();
        let mut finished: Vec<Option<ToolExecutionResult>> =
            std::iter::repeat_with(|| None).take(calls.len()).collect();
        std::future::poll_fn(|cx| {
            let mut all_ready = true;
            for (fut, slot) in pending.iter_mut().zip(finished.iter_mut()) {
                // A completed future must not be polled again.
                if slot.is_some() {
                    continue;
                }
                match std::future::Future::poll(fut.as_mut(), cx) {
                    std::task::Poll::Ready(result) => *slot = Some(result),
                    std::task::Poll::Pending => all_ready = false,
                }
            }
            if all_ready {
                std::task::Poll::Ready(())
            } else {
                std::task::Poll::Pending
            }
        })
        .await;
        finished
            .into_iter()
            .map(|slot| slot.expect("poll_fn completes only after every call produced a result"))
            .collect()
    }

    /// Processes one user message and returns the agent's final reply.
    ///
    /// 1. On an empty history, builds and inserts the system prompt.
    /// 2. If `auto_save` is set, stores the user message in memory (best-effort).
    /// 3. Prepends loaded memory context to the message and appends it to history.
    /// 4. Repeats up to `config.max_tool_iterations` times:
    ///    - ask the provider;
    ///    - if the response has no tool calls, record and return the text;
    ///    - otherwise print any interim text, execute the calls, and append
    ///      their results to history.
    ///
    /// # Errors
    /// Fails if the system prompt cannot be built, if a provider call fails,
    /// or if the iteration limit is reached without a final answer.
    pub async fn turn(&mut self, user_message: &str) -> Result<String> {
        if self.history.is_empty() {
            let system_prompt = self.build_system_prompt()?;
            self.history
                .push(ConversationMessage::Chat(ChatMessage::system(system_prompt)));
        }

        if self.auto_save {
            // Best-effort: a failed memory write must not abort the turn.
            let _ = self
                .memory
                .store("user_msg", user_message, MemoryCategory::Conversation, None)
                .await;
        }

        // Best-effort: a context-loading failure just yields an un-enriched message.
        let context = self
            .memory_loader
            .load_context(self.memory.as_ref(), user_message)
            .await
            .unwrap_or_default();
        let enriched = if context.is_empty() {
            user_message.to_string()
        } else {
            format!("{context}{user_message}")
        };
        self.history
            .push(ConversationMessage::Chat(ChatMessage::user(enriched)));

        for _ in 0..self.config.max_tool_iterations {
            let messages = self.tool_dispatcher.to_provider_messages(&self.history);
            let response = self
                .provider
                .chat(
                    ChatRequest {
                        messages: &messages,
                        tools: if self.tool_dispatcher.should_send_tool_specs() {
                            Some(&self.tool_specs)
                        } else {
                            None
                        },
                    },
                    &self.model_name,
                    self.temperature,
                )
                .await?;

            let (text, calls) = self.tool_dispatcher.parse_response(&response);
            if calls.is_empty() {
                // Prefer the dispatcher's parsed text; fall back to the raw response text.
                let final_text = if text.is_empty() {
                    response.text.unwrap_or_default()
                } else {
                    text
                };
                self.history
                    .push(ConversationMessage::Chat(ChatMessage::assistant(
                        final_text.clone(),
                    )));
                self.trim_history();
                if self.auto_save {
                    let summary = truncate_with_ellipsis(&final_text, 100);
                    let _ = self
                        .memory
                        .store("assistant_resp", &summary, MemoryCategory::Daily, None)
                        .await;
                }
                return Ok(final_text);
            }

            if !text.is_empty() {
                self.history
                    .push(ConversationMessage::Chat(ChatMessage::assistant(
                        text.clone(),
                    )));
                // Surface interim narration to the user while tools run.
                print!("{text}");
                let _ = std::io::stdout().flush();
            }
            self.history.push(ConversationMessage::AssistantToolCalls {
                text: response.text.clone(),
                tool_calls: response.tool_calls.clone(),
            });
            let results = self.execute_tools(&calls).await;
            let formatted = self.tool_dispatcher.format_results(&results);
            self.history.push(formatted);
            self.trim_history();
        }

        anyhow::bail!(
            "Agent exceeded maximum tool iterations ({})",
            self.config.max_tool_iterations
        )
    }

    /// Answers a single message; equivalent to [`Agent::turn`].
    pub async fn run_single(&mut self, message: &str) -> Result<String> {
        self.turn(message).await
    }

    /// Runs a read-eval-print loop over the CLI channel until it stops producing
    /// messages.
    ///
    /// Errors from individual turns are printed to stderr and the loop
    /// continues. The listener task is aborted on exit.
    pub async fn run_interactive(&mut self) -> Result<()> {
        println!("🦀 ZeroClaw Interactive Mode");
        println!("Type /quit to exit.\n");
        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
        let cli = crate::channels::CliChannel::new();
        let listen_handle = tokio::spawn(async move {
            let _ = crate::channels::Channel::listen(&cli, tx).await;
        });
        while let Some(msg) = rx.recv().await {
            let response = match self.turn(&msg.content).await {
                Ok(resp) => resp,
                Err(e) => {
                    eprintln!("\nError: {e}\n");
                    continue;
                }
            };
            println!("\n{response}\n");
        }
        listen_handle.abort();
        Ok(())
    }
}
/// Entry point for running the agent from the command line.
///
/// Builds an [`Agent`] from `config`, applying the optional provider and model
/// overrides and the given temperature. It then either answers `message` once
/// (printing the reply) or, with no message, starts the interactive loop.
///
/// An `AgentStart` event is recorded before the run and an `AgentEnd` event
/// after it. `AgentEnd` is recorded even when the run fails, so every start
/// event has a matching end event.
///
/// # Errors
/// Fails if the agent cannot be built, or if the single-shot turn or the
/// interactive loop returns an error.
pub async fn run(
    config: Config,
    message: Option<String>,
    provider_override: Option<String>,
    model_override: Option<String>,
    temperature: f64,
) -> Result<()> {
    let start = Instant::now();
    let mut effective_config = config;
    if let Some(p) = provider_override {
        effective_config.default_provider = Some(p);
    }
    if let Some(m) = model_override {
        effective_config.default_model = Some(m);
    }
    effective_config.default_temperature = temperature;

    let mut agent = Agent::from_config(&effective_config)?;

    // Same fallbacks as `Agent::from_config`, resolved here for event labels.
    let provider_name = effective_config
        .default_provider
        .as_deref()
        .unwrap_or("openrouter")
        .to_string();
    let model_name = effective_config
        .default_model
        .as_deref()
        .unwrap_or("anthropic/claude-sonnet-4-20250514")
        .to_string();

    agent.observer.record_event(&ObserverEvent::AgentStart {
        provider: provider_name.clone(),
        model: model_name.clone(),
    });

    // Capture the outcome instead of `?`-returning so AgentEnd is always recorded.
    let outcome = if let Some(msg) = message {
        agent.run_single(&msg).await.map(|response| {
            println!("{response}");
        })
    } else {
        agent.run_interactive().await
    };

    agent.observer.record_event(&ObserverEvent::AgentEnd {
        provider: provider_name,
        model: model_name,
        duration: start.elapsed(),
        tokens_used: None,
        cost_usd: None,
    });
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use parking_lot::Mutex;

    /// Provider that replays a scripted queue of chat responses in order and
    /// answers `"done"` once the queue is exhausted.
    struct MockProvider {
        responses: Mutex<Vec<crate::providers::ChatResponse>>,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> Result<String> {
            Ok("ok".into())
        }

        async fn chat(
            &self,
            _request: ChatRequest<'_>,
            _model: &str,
            _temperature: f64,
        ) -> Result<crate::providers::ChatResponse> {
            let mut queue = self.responses.lock();
            let next = if queue.is_empty() {
                crate::providers::ChatResponse {
                    text: Some("done".into()),
                    tool_calls: vec![],
                }
            } else {
                queue.remove(0)
            };
            Ok(next)
        }
    }

    /// Tool named `echo` that always succeeds with `"tool-out"`.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echo"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
            Ok(crate::tools::ToolResult {
                success: true,
                output: "tool-out".into(),
                error: None,
            })
        }
    }

    /// Wraps a response script in a boxed [`MockProvider`].
    fn scripted_provider(script: Vec<crate::providers::ChatResponse>) -> Box<MockProvider> {
        Box::new(MockProvider {
            responses: Mutex::new(script),
        })
    }

    /// Builds an agent with the `echo` tool, a `"none"` memory backend rooted
    /// at `/tmp`, a no-op observer and the given provider/dispatcher.
    fn agent_with(provider: Box<MockProvider>, dispatcher: Box<dyn ToolDispatcher>) -> Agent {
        let memory_cfg = crate::config::MemoryConfig {
            backend: "none".into(),
            ..crate::config::MemoryConfig::default()
        };
        let mem: Arc<dyn Memory> = Arc::from(
            crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(),
        );
        let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
        Agent::builder()
            .provider(provider)
            .tools(vec![Box::new(MockTool)])
            .memory(mem)
            .observer(observer)
            .tool_dispatcher(dispatcher)
            .workspace_dir(std::path::PathBuf::from("/tmp"))
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn turn_without_tools_returns_text() {
        let provider = scripted_provider(vec![crate::providers::ChatResponse {
            text: Some("hello".into()),
            tool_calls: vec![],
        }]);
        let mut agent = agent_with(provider, Box::new(XmlToolDispatcher));

        let reply = agent.turn("hi").await.unwrap();

        assert_eq!(reply, "hello");
    }

    #[tokio::test]
    async fn turn_with_native_dispatcher_handles_tool_results_variant() {
        let tool_call_step = crate::providers::ChatResponse {
            text: Some(String::new()),
            tool_calls: vec![crate::providers::ToolCall {
                id: "tc1".into(),
                name: "echo".into(),
                arguments: "{}".into(),
            }],
        };
        let final_step = crate::providers::ChatResponse {
            text: Some("done".into()),
            tool_calls: vec![],
        };
        let provider = scripted_provider(vec![tool_call_step, final_step]);
        let mut agent = agent_with(provider, Box::new(NativeToolDispatcher));

        let reply = agent.turn("hi").await.unwrap();

        assert_eq!(reply, "done");
        let saw_tool_results = agent
            .history()
            .iter()
            .any(|msg| matches!(msg, ConversationMessage::ToolResults(_)));
        assert!(saw_tool_results);
    }
}