feat: add agent structure and improve tooling for provider
This commit is contained in:
parent
e2c966d31e
commit
b341fdb368
21 changed files with 2567 additions and 443 deletions
701
src/agent/agent.rs
Normal file
701
src/agent/agent.rs
Normal file
|
|
@ -0,0 +1,701 @@
|
|||
use crate::agent::dispatcher::{
|
||||
NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher,
|
||||
};
|
||||
use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader};
|
||||
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
|
||||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::observability::{self, Observer, ObserverEvent};
|
||||
use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Provider};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool, ToolSpec};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use std::io::Write as IoWrite;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
pub struct Agent {
|
||||
provider: Box<dyn Provider>,
|
||||
tools: Vec<Box<dyn Tool>>,
|
||||
tool_specs: Vec<ToolSpec>,
|
||||
memory: Arc<dyn Memory>,
|
||||
observer: Arc<dyn Observer>,
|
||||
prompt_builder: SystemPromptBuilder,
|
||||
tool_dispatcher: Box<dyn ToolDispatcher>,
|
||||
memory_loader: Box<dyn MemoryLoader>,
|
||||
config: crate::config::AgentConfig,
|
||||
model_name: String,
|
||||
temperature: f64,
|
||||
workspace_dir: std::path::PathBuf,
|
||||
identity_config: crate::config::IdentityConfig,
|
||||
skills: Vec<crate::skills::Skill>,
|
||||
auto_save: bool,
|
||||
history: Vec<ConversationMessage>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
provider: Option<Box<dyn Provider>>,
|
||||
tools: Option<Vec<Box<dyn Tool>>>,
|
||||
memory: Option<Arc<dyn Memory>>,
|
||||
observer: Option<Arc<dyn Observer>>,
|
||||
prompt_builder: Option<SystemPromptBuilder>,
|
||||
tool_dispatcher: Option<Box<dyn ToolDispatcher>>,
|
||||
memory_loader: Option<Box<dyn MemoryLoader>>,
|
||||
config: Option<crate::config::AgentConfig>,
|
||||
model_name: Option<String>,
|
||||
temperature: Option<f64>,
|
||||
workspace_dir: Option<std::path::PathBuf>,
|
||||
identity_config: Option<crate::config::IdentityConfig>,
|
||||
skills: Option<Vec<crate::skills::Skill>>,
|
||||
auto_save: Option<bool>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
provider: None,
|
||||
tools: None,
|
||||
memory: None,
|
||||
observer: None,
|
||||
prompt_builder: None,
|
||||
tool_dispatcher: None,
|
||||
memory_loader: None,
|
||||
config: None,
|
||||
model_name: None,
|
||||
temperature: None,
|
||||
workspace_dir: None,
|
||||
identity_config: None,
|
||||
skills: None,
|
||||
auto_save: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider(mut self, provider: Box<dyn Provider>) -> Self {
|
||||
self.provider = Some(provider);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn memory(mut self, memory: Arc<dyn Memory>) -> Self {
|
||||
self.memory = Some(memory);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||||
self.observer = Some(observer);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn prompt_builder(mut self, prompt_builder: SystemPromptBuilder) -> Self {
|
||||
self.prompt_builder = Some(prompt_builder);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tool_dispatcher(mut self, tool_dispatcher: Box<dyn ToolDispatcher>) -> Self {
|
||||
self.tool_dispatcher = Some(tool_dispatcher);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn memory_loader(mut self, memory_loader: Box<dyn MemoryLoader>) -> Self {
|
||||
self.memory_loader = Some(memory_loader);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn config(mut self, config: crate::config::AgentConfig) -> Self {
|
||||
self.config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn model_name(mut self, model_name: String) -> Self {
|
||||
self.model_name = Some(model_name);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature(mut self, temperature: f64) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn workspace_dir(mut self, workspace_dir: std::path::PathBuf) -> Self {
|
||||
self.workspace_dir = Some(workspace_dir);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn identity_config(mut self, identity_config: crate::config::IdentityConfig) -> Self {
|
||||
self.identity_config = Some(identity_config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
|
||||
self.skills = Some(skills);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn auto_save(mut self, auto_save: bool) -> Self {
|
||||
self.auto_save = Some(auto_save);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let tools = self
|
||||
.tools
|
||||
.ok_or_else(|| anyhow::anyhow!("tools are required"))?;
|
||||
let tool_specs = tools.iter().map(|tool| tool.spec()).collect();
|
||||
|
||||
Ok(Agent {
|
||||
provider: self
|
||||
.provider
|
||||
.ok_or_else(|| anyhow::anyhow!("provider is required"))?,
|
||||
tools,
|
||||
tool_specs,
|
||||
memory: self
|
||||
.memory
|
||||
.ok_or_else(|| anyhow::anyhow!("memory is required"))?,
|
||||
observer: self
|
||||
.observer
|
||||
.ok_or_else(|| anyhow::anyhow!("observer is required"))?,
|
||||
prompt_builder: self
|
||||
.prompt_builder
|
||||
.unwrap_or_else(SystemPromptBuilder::with_defaults),
|
||||
tool_dispatcher: self
|
||||
.tool_dispatcher
|
||||
.ok_or_else(|| anyhow::anyhow!("tool_dispatcher is required"))?,
|
||||
memory_loader: self
|
||||
.memory_loader
|
||||
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
|
||||
config: self.config.unwrap_or_default(),
|
||||
model_name: self
|
||||
.model_name
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
|
||||
temperature: self.temperature.unwrap_or(0.7),
|
||||
workspace_dir: self
|
||||
.workspace_dir
|
||||
.unwrap_or_else(|| std::path::PathBuf::from(".")),
|
||||
identity_config: self.identity_config.unwrap_or_default(),
|
||||
skills: self.skills.unwrap_or_default(),
|
||||
auto_save: self.auto_save.unwrap_or(false),
|
||||
history: Vec::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn builder() -> AgentBuilder {
|
||||
AgentBuilder::new()
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &[ConversationMessage] {
|
||||
&self.history
|
||||
}
|
||||
|
||||
pub fn clear_history(&mut self) {
|
||||
self.history.clear();
|
||||
}
|
||||
|
||||
pub fn from_config(config: &Config) -> Result<Self> {
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||
let security = Arc::new(SecurityPolicy::from_config(
|
||||
&config.autonomy,
|
||||
&config.workspace_dir,
|
||||
));
|
||||
|
||||
let memory: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||
&config.memory,
|
||||
&config.workspace_dir,
|
||||
config.api_key.as_deref(),
|
||||
)?);
|
||||
|
||||
let composio_key = if config.composio.enabled {
|
||||
config.composio.api_key.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let tools = tools::all_tools_with_runtime(
|
||||
&security,
|
||||
runtime,
|
||||
memory.clone(),
|
||||
composio_key,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.workspace_dir,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
);
|
||||
|
||||
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
|
||||
let model_name = config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
|
||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
)?;
|
||||
|
||||
let dispatcher_choice = config.agent.tool_dispatcher.as_str();
|
||||
let tool_dispatcher: Box<dyn ToolDispatcher> = match dispatcher_choice {
|
||||
"native" => Box::new(NativeToolDispatcher),
|
||||
"xml" => Box::new(XmlToolDispatcher),
|
||||
_ if provider.supports_native_tools() => Box::new(NativeToolDispatcher),
|
||||
_ => Box::new(XmlToolDispatcher),
|
||||
};
|
||||
|
||||
Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(tools)
|
||||
.memory(memory)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(tool_dispatcher)
|
||||
.memory_loader(Box::new(DefaultMemoryLoader::default()))
|
||||
.prompt_builder(SystemPromptBuilder::with_defaults())
|
||||
.config(config.agent.clone())
|
||||
.model_name(model_name)
|
||||
.temperature(config.default_temperature)
|
||||
.workspace_dir(config.workspace_dir.clone())
|
||||
.identity_config(config.identity.clone())
|
||||
.skills(crate::skills::load_skills(&config.workspace_dir))
|
||||
.auto_save(config.memory.auto_save)
|
||||
.build()
|
||||
}
|
||||
|
||||
fn trim_history(&mut self) {
|
||||
let max = self.config.max_history_messages;
|
||||
if self.history.len() <= max {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut system_messages = Vec::new();
|
||||
let mut other_messages = Vec::new();
|
||||
|
||||
for msg in self.history.drain(..) {
|
||||
match &msg {
|
||||
ConversationMessage::Chat(chat) if chat.role == "system" => {
|
||||
system_messages.push(msg)
|
||||
}
|
||||
_ => other_messages.push(msg),
|
||||
}
|
||||
}
|
||||
|
||||
if other_messages.len() > max {
|
||||
let drop_count = other_messages.len() - max;
|
||||
other_messages.drain(0..drop_count);
|
||||
}
|
||||
|
||||
self.history = system_messages;
|
||||
self.history.extend(other_messages);
|
||||
}
|
||||
|
||||
fn build_system_prompt(&self) -> Result<String> {
|
||||
let instructions = self.tool_dispatcher.prompt_instructions(&self.tools);
|
||||
let ctx = PromptContext {
|
||||
workspace_dir: &self.workspace_dir,
|
||||
model_name: &self.model_name,
|
||||
tools: &self.tools,
|
||||
skills: &self.skills,
|
||||
identity_config: Some(&self.identity_config),
|
||||
dispatcher_instructions: &instructions,
|
||||
};
|
||||
self.prompt_builder.build(&ctx)
|
||||
}
|
||||
|
||||
async fn execute_tool_call(&self, call: &ParsedToolCall) -> ToolExecutionResult {
|
||||
let start = Instant::now();
|
||||
|
||||
let result = if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) {
|
||||
match tool.execute(call.arguments.clone()).await {
|
||||
Ok(r) => {
|
||||
self.observer.record_event(&ObserverEvent::ToolCall {
|
||||
tool: call.name.clone(),
|
||||
duration: start.elapsed(),
|
||||
success: r.success,
|
||||
});
|
||||
if r.success {
|
||||
r.output
|
||||
} else {
|
||||
format!("Error: {}", r.error.unwrap_or(r.output))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.observer.record_event(&ObserverEvent::ToolCall {
|
||||
tool: call.name.clone(),
|
||||
duration: start.elapsed(),
|
||||
success: false,
|
||||
});
|
||||
format!("Error executing {}: {e}", call.name)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
format!("Unknown tool: {}", call.name)
|
||||
};
|
||||
|
||||
ToolExecutionResult {
|
||||
name: call.name.clone(),
|
||||
output: result,
|
||||
success: true,
|
||||
tool_call_id: call.tool_call_id.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_tools(&self, calls: &[ParsedToolCall]) -> Vec<ToolExecutionResult> {
|
||||
if !self.config.parallel_tools {
|
||||
let mut results = Vec::with_capacity(calls.len());
|
||||
for call in calls {
|
||||
results.push(self.execute_tool_call(call).await);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(calls.len());
|
||||
for call in calls {
|
||||
results.push(self.execute_tool_call(call).await);
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn turn(&mut self, user_message: &str) -> Result<String> {
|
||||
if self.history.is_empty() {
|
||||
let system_prompt = self.build_system_prompt()?;
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::system(
|
||||
system_prompt,
|
||||
)));
|
||||
}
|
||||
|
||||
if self.auto_save {
|
||||
let _ = self
|
||||
.memory
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation)
|
||||
.await;
|
||||
}
|
||||
|
||||
let context = self
|
||||
.memory_loader
|
||||
.load_context(self.memory.as_ref(), user_message)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
let enriched = if context.is_empty() {
|
||||
user_message.to_string()
|
||||
} else {
|
||||
format!("{context}{user_message}")
|
||||
};
|
||||
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::user(enriched)));
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
let response = match self
|
||||
.provider
|
||||
.chat(
|
||||
ChatRequest {
|
||||
messages: &messages,
|
||||
tools: if self.tool_dispatcher.should_send_tool_specs() {
|
||||
Some(&self.tool_specs)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
},
|
||||
&self.model_name,
|
||||
self.temperature,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let (text, calls) = self.tool_dispatcher.parse_response(&response);
|
||||
if calls.is_empty() {
|
||||
let final_text = if text.is_empty() {
|
||||
response.text.unwrap_or_default()
|
||||
} else {
|
||||
text
|
||||
};
|
||||
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
final_text.clone(),
|
||||
)));
|
||||
self.trim_history();
|
||||
|
||||
if self.auto_save {
|
||||
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||
let _ = self
|
||||
.memory
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Ok(final_text);
|
||||
}
|
||||
|
||||
if !text.is_empty() {
|
||||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||
text.clone(),
|
||||
)));
|
||||
print!("{text}");
|
||||
let _ = std::io::stdout().flush();
|
||||
}
|
||||
|
||||
self.history.push(ConversationMessage::AssistantToolCalls {
|
||||
text: response.text.clone(),
|
||||
tool_calls: response.tool_calls.clone(),
|
||||
});
|
||||
|
||||
let results = self.execute_tools(&calls).await;
|
||||
let formatted = self.tool_dispatcher.format_results(&results);
|
||||
self.history.push(formatted);
|
||||
self.trim_history();
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Agent exceeded maximum tool iterations ({})",
|
||||
self.config.max_tool_iterations
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn run_single(&mut self, message: &str) -> Result<String> {
|
||||
self.turn(message).await
|
||||
}
|
||||
|
||||
pub async fn run_interactive(&mut self) -> Result<()> {
|
||||
println!("🦀 ZeroClaw Interactive Mode");
|
||||
println!("Type /quit to exit.\n");
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
|
||||
let cli = crate::channels::CliChannel::new();
|
||||
|
||||
let listen_handle = tokio::spawn(async move {
|
||||
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
||||
});
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
let response = match self.turn(&msg.content).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
eprintln!("\nError: {e}\n");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
println!("\n{response}\n");
|
||||
}
|
||||
|
||||
listen_handle.abort();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
config: Config,
|
||||
message: Option<String>,
|
||||
provider_override: Option<String>,
|
||||
model_override: Option<String>,
|
||||
temperature: f64,
|
||||
) -> Result<()> {
|
||||
let start = Instant::now();
|
||||
|
||||
let mut effective_config = config;
|
||||
if let Some(p) = provider_override {
|
||||
effective_config.default_provider = Some(p);
|
||||
}
|
||||
if let Some(m) = model_override {
|
||||
effective_config.default_model = Some(m);
|
||||
}
|
||||
effective_config.default_temperature = temperature;
|
||||
|
||||
let mut agent = Agent::from_config(&effective_config)?;
|
||||
|
||||
let provider_name = effective_config
|
||||
.default_provider
|
||||
.as_deref()
|
||||
.unwrap_or("openrouter")
|
||||
.to_string();
|
||||
let model_name = effective_config
|
||||
.default_model
|
||||
.as_deref()
|
||||
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||
.to_string();
|
||||
|
||||
agent.observer.record_event(&ObserverEvent::AgentStart {
|
||||
provider: provider_name,
|
||||
model: model_name,
|
||||
});
|
||||
|
||||
if let Some(msg) = message {
|
||||
let response = agent.run_single(&msg).await?;
|
||||
println!("{response}");
|
||||
} else {
|
||||
agent.run_interactive().await?;
|
||||
}
|
||||
|
||||
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: start.elapsed(),
|
||||
tokens_used: None,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Mutex;
|
||||
|
||||
struct MockProvider {
|
||||
responses: Mutex<Vec<crate::providers::ChatResponse>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MockProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> Result<String> {
|
||||
Ok("ok".into())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> Result<crate::providers::ChatResponse> {
|
||||
let mut guard = self.responses.lock().unwrap();
|
||||
if guard.is_empty() {
|
||||
return Ok(crate::providers::ChatResponse {
|
||||
text: Some("done".into()),
|
||||
tool_calls: vec![],
|
||||
});
|
||||
}
|
||||
Ok(guard.remove(0))
|
||||
}
|
||||
}
|
||||
|
||||
struct MockTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockTool {
|
||||
fn name(&self) -> &str {
|
||||
"echo"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"echo"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({"type": "object"})
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: true,
|
||||
output: "tool-out".into(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_without_tools_returns_text() {
|
||||
let provider = Box::new(MockProvider {
|
||||
responses: Mutex::new(vec![crate::providers::ChatResponse {
|
||||
text: Some("hello".into()),
|
||||
tool_calls: vec![],
|
||||
}]),
|
||||
});
|
||||
|
||||
let memory_cfg = crate::config::MemoryConfig {
|
||||
backend: "none".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let mem: Arc<dyn Memory> = Arc::from(
|
||||
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(),
|
||||
);
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||
let mut agent = Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(vec![Box::new(MockTool)])
|
||||
.memory(mem)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(Box::new(XmlToolDispatcher))
|
||||
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let response = agent.turn("hi").await.unwrap();
|
||||
assert_eq!(response, "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_with_native_dispatcher_handles_tool_results_variant() {
|
||||
let provider = Box::new(MockProvider {
|
||||
responses: Mutex::new(vec![
|
||||
crate::providers::ChatResponse {
|
||||
text: Some("".into()),
|
||||
tool_calls: vec![crate::providers::ToolCall {
|
||||
id: "tc1".into(),
|
||||
name: "echo".into(),
|
||||
arguments: "{}".into(),
|
||||
}],
|
||||
},
|
||||
crate::providers::ChatResponse {
|
||||
text: Some("done".into()),
|
||||
tool_calls: vec![],
|
||||
},
|
||||
]),
|
||||
});
|
||||
|
||||
let memory_cfg = crate::config::MemoryConfig {
|
||||
backend: "none".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let mem: Arc<dyn Memory> = Arc::from(
|
||||
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(),
|
||||
);
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||
let mut agent = Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(vec![Box::new(MockTool)])
|
||||
.memory(mem)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let response = agent.turn("hi").await.unwrap();
|
||||
assert_eq!(response, "done");
|
||||
assert!(matches!(
|
||||
agent
|
||||
.history()
|
||||
.iter()
|
||||
.find(|msg| matches!(msg, ConversationMessage::ToolResults(_))),
|
||||
Some(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
312
src/agent/dispatcher.rs
Normal file
312
src/agent/dispatcher.rs
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
|
||||
use crate::tools::{Tool, ToolSpec};
|
||||
use serde_json::Value;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ParsedToolCall {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolExecutionResult {
|
||||
pub name: String,
|
||||
pub output: String,
|
||||
pub success: bool,
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
pub trait ToolDispatcher: Send + Sync {
|
||||
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
|
||||
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
|
||||
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
|
||||
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
|
||||
fn should_send_tool_specs(&self) -> bool;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct XmlToolDispatcher;
|
||||
|
||||
impl XmlToolDispatcher {
|
||||
fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||
let mut text_parts = Vec::new();
|
||||
let mut calls = Vec::new();
|
||||
let mut remaining = response;
|
||||
|
||||
while let Some(start) = remaining.find("<tool_call>") {
|
||||
let before = &remaining[..start];
|
||||
if !before.trim().is_empty() {
|
||||
text_parts.push(before.trim().to_string());
|
||||
}
|
||||
|
||||
if let Some(end) = remaining[start..].find("</tool_call>") {
|
||||
let inner = &remaining[start + 11..start + end];
|
||||
match serde_json::from_str::<Value>(inner.trim()) {
|
||||
Ok(parsed) => {
|
||||
let name = parsed
|
||||
.get("name")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
if name.is_empty() {
|
||||
remaining = &remaining[start + end + 12..];
|
||||
continue;
|
||||
}
|
||||
let arguments = parsed
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| Value::Object(serde_json::Map::new()));
|
||||
calls.push(ParsedToolCall {
|
||||
name,
|
||||
arguments,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Malformed <tool_call> JSON: {e}");
|
||||
}
|
||||
}
|
||||
remaining = &remaining[start + end + 12..];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !remaining.trim().is_empty() {
|
||||
text_parts.push(remaining.trim().to_string());
|
||||
}
|
||||
|
||||
(text_parts.join("\n"), calls)
|
||||
}
|
||||
|
||||
pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
|
||||
tools.iter().map(|tool| tool.spec()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolDispatcher for XmlToolDispatcher {
|
||||
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
||||
let text = response.text_or_empty();
|
||||
Self::parse_xml_tool_calls(text)
|
||||
}
|
||||
|
||||
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
||||
let mut content = String::new();
|
||||
for result in results {
|
||||
let status = if result.success { "ok" } else { "error" };
|
||||
let _ = writeln!(
|
||||
content,
|
||||
"<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
|
||||
result.name, status, result.output
|
||||
);
|
||||
}
|
||||
ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
|
||||
}
|
||||
|
||||
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
|
||||
let mut instructions = String::new();
|
||||
instructions.push_str("## Tool Use Protocol\n\n");
|
||||
instructions
|
||||
.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||
instructions.push_str(
|
||||
"```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
|
||||
);
|
||||
instructions.push_str("### Available Tools\n\n");
|
||||
|
||||
for tool in tools {
|
||||
let _ = writeln!(
|
||||
instructions,
|
||||
"- **{}**: {}\n Parameters: `{}`",
|
||||
tool.name(),
|
||||
tool.description(),
|
||||
tool.parameters_schema()
|
||||
);
|
||||
}
|
||||
|
||||
instructions
|
||||
}
|
||||
|
||||
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
||||
history
|
||||
.iter()
|
||||
.flat_map(|msg| match msg {
|
||||
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
||||
ConversationMessage::AssistantToolCalls { text, .. } => {
|
||||
vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
|
||||
}
|
||||
ConversationMessage::ToolResults(results) => {
|
||||
let mut content = String::new();
|
||||
for result in results {
|
||||
let _ = writeln!(
|
||||
content,
|
||||
"<tool_result id=\"{}\">\n{}\n</tool_result>",
|
||||
result.tool_call_id, result.content
|
||||
);
|
||||
}
|
||||
vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn should_send_tool_specs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NativeToolDispatcher;
|
||||
|
||||
impl ToolDispatcher for NativeToolDispatcher {
|
||||
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
||||
let text = response.text.clone().unwrap_or_default();
|
||||
let calls = response
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| ParsedToolCall {
|
||||
name: tc.name.clone(),
|
||||
arguments: serde_json::from_str(&tc.arguments)
|
||||
.unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
|
||||
tool_call_id: Some(tc.id.clone()),
|
||||
})
|
||||
.collect();
|
||||
(text, calls)
|
||||
}
|
||||
|
||||
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
||||
let messages = results
|
||||
.iter()
|
||||
.map(|result| ToolResultMessage {
|
||||
tool_call_id: result
|
||||
.tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string()),
|
||||
content: result.output.clone(),
|
||||
})
|
||||
.collect();
|
||||
ConversationMessage::ToolResults(messages)
|
||||
}
|
||||
|
||||
fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
|
||||
String::new()
|
||||
}
|
||||
|
||||
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
||||
history
|
||||
.iter()
|
||||
.flat_map(|msg| match msg {
|
||||
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
||||
ConversationMessage::AssistantToolCalls { text, tool_calls } => {
|
||||
let payload = serde_json::json!({
|
||||
"content": text,
|
||||
"tool_calls": tool_calls,
|
||||
});
|
||||
vec![ChatMessage::assistant(payload.to_string())]
|
||||
}
|
||||
ConversationMessage::ToolResults(results) => results
|
||||
.iter()
|
||||
.map(|result| {
|
||||
ChatMessage::tool(
|
||||
serde_json::json!({
|
||||
"tool_call_id": result.tool_call_id,
|
||||
"content": result.content,
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn should_send_tool_specs(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn xml_dispatcher_parses_tool_calls() {
|
||||
let response = ChatResponse {
|
||||
text: Some(
|
||||
"Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>"
|
||||
.into(),
|
||||
),
|
||||
tool_calls: vec![],
|
||||
};
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
let (_, calls) = dispatcher.parse_response(&response);
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_dispatcher_roundtrip() {
|
||||
let response = ChatResponse {
|
||||
text: Some("ok".into()),
|
||||
tool_calls: vec![crate::providers::ToolCall {
|
||||
id: "tc1".into(),
|
||||
name: "file_read".into(),
|
||||
arguments: "{\"path\":\"a.txt\"}".into(),
|
||||
}],
|
||||
};
|
||||
let dispatcher = NativeToolDispatcher;
|
||||
let (_, calls) = dispatcher.parse_response(&response);
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));
|
||||
|
||||
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||
name: "file_read".into(),
|
||||
output: "hello".into(),
|
||||
success: true,
|
||||
tool_call_id: Some("tc1".into()),
|
||||
}]);
|
||||
match msg {
|
||||
ConversationMessage::ToolResults(results) => {
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].tool_call_id, "tc1");
|
||||
}
|
||||
_ => panic!("expected tool results"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn xml_format_results_contains_tool_result_tags() {
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||
name: "shell".into(),
|
||||
output: "ok".into(),
|
||||
success: true,
|
||||
tool_call_id: None,
|
||||
}]);
|
||||
let rendered = match msg {
|
||||
ConversationMessage::Chat(chat) => chat.content,
|
||||
_ => String::new(),
|
||||
};
|
||||
assert!(rendered.contains("<tool_result"));
|
||||
assert!(rendered.contains("shell"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_format_results_keeps_tool_call_id() {
|
||||
let dispatcher = NativeToolDispatcher;
|
||||
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||
name: "shell".into(),
|
||||
output: "ok".into(),
|
||||
success: true,
|
||||
tool_call_id: Some("tc-1".into()),
|
||||
}]);
|
||||
|
||||
match msg {
|
||||
ConversationMessage::ToolResults(results) => {
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].tool_call_id, "tc-1");
|
||||
}
|
||||
_ => panic!("expected ToolResults variant"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,11 +8,10 @@ use crate::tools::{self, Tool};
|
|||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use std::fmt::Write;
|
||||
use std::io::Write as IoWrite;
|
||||
use std::io::Write as _;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||
|
||||
|
|
@ -113,7 +112,6 @@ async fn auto_compact_history(
|
|||
let summary_raw = provider
|
||||
.chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2)
|
||||
.await
|
||||
.map(|resp| resp.text_or_empty().to_string())
|
||||
.unwrap_or_else(|_| {
|
||||
// Fallback to deterministic local truncation when summarization fails.
|
||||
truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS)
|
||||
|
|
@ -482,21 +480,11 @@ pub(crate) async fn run_tool_call_loop(
|
|||
}
|
||||
};
|
||||
|
||||
let response_text = response.text.unwrap_or_default();
|
||||
let response_text = response;
|
||||
let mut assistant_history_content = response_text.clone();
|
||||
let mut parsed_text = response_text.clone();
|
||||
let mut tool_calls = parse_structured_tool_calls(&response.tool_calls);
|
||||
|
||||
if !response.tool_calls.is_empty() {
|
||||
assistant_history_content =
|
||||
build_assistant_history_with_tool_calls(&response_text, &response.tool_calls);
|
||||
}
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
parsed_text = fallback_text;
|
||||
tool_calls = fallback_calls;
|
||||
}
|
||||
let (parsed_text, tool_calls) = parse_tool_calls(&response_text);
|
||||
let mut parsed_text = parsed_text;
|
||||
let mut tool_calls = tool_calls;
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
// No tool calls — this is the final response
|
||||
|
|
|
|||
118
src/agent/memory_loader.rs
Normal file
118
src/agent/memory_loader.rs
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
use crate::memory::Memory;
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
#[async_trait]
|
||||
pub trait MemoryLoader: Send + Sync {
|
||||
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
|
||||
-> anyhow::Result<String>;
|
||||
}
|
||||
|
||||
pub struct DefaultMemoryLoader {
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
impl Default for DefaultMemoryLoader {
|
||||
fn default() -> Self {
|
||||
Self { limit: 5 }
|
||||
}
|
||||
}
|
||||
|
||||
impl DefaultMemoryLoader {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit: limit.max(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MemoryLoader for DefaultMemoryLoader {
|
||||
async fn load_context(
|
||||
&self,
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
context.push('\n');
|
||||
Ok(context)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
|
||||
|
||||
struct MockMemory;
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for MockMemory {
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if limit == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
Ok(vec![MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "k".into(),
|
||||
content: "v".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_loader_formats_context() {
|
||||
let loader = DefaultMemoryLoader::default();
|
||||
let context = loader.load_context(&MockMemory, "hello").await.unwrap();
|
||||
assert!(context.contains("[Memory context]"));
|
||||
assert!(context.contains("- k: v"));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,24 @@
|
|||
pub mod agent;
|
||||
pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use agent::{Agent, AgentBuilder};
|
||||
pub use loop_::{process_message, run};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn assert_reexport_exists<F>(_value: F) {}
|
||||
|
||||
#[test]
|
||||
fn run_function_is_reexported() {
|
||||
assert_reexport_exists(run);
|
||||
assert_reexport_exists(process_message);
|
||||
assert_reexport_exists(loop_::run);
|
||||
assert_reexport_exists(loop_::process_message);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
304
src/agent/prompt.rs
Normal file
304
src/agent/prompt.rs
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
use crate::config::IdentityConfig;
|
||||
use crate::identity;
|
||||
use crate::skills::Skill;
|
||||
use crate::tools::Tool;
|
||||
use anyhow::Result;
|
||||
use chrono::Local;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
|
||||
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||||
|
||||
pub struct PromptContext<'a> {
|
||||
pub workspace_dir: &'a Path,
|
||||
pub model_name: &'a str,
|
||||
pub tools: &'a [Box<dyn Tool>],
|
||||
pub skills: &'a [Skill],
|
||||
pub identity_config: Option<&'a IdentityConfig>,
|
||||
pub dispatcher_instructions: &'a str,
|
||||
}
|
||||
|
||||
pub trait PromptSection: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String>;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct SystemPromptBuilder {
|
||||
sections: Vec<Box<dyn PromptSection>>,
|
||||
}
|
||||
|
||||
impl SystemPromptBuilder {
|
||||
pub fn with_defaults() -> Self {
|
||||
Self {
|
||||
sections: vec![
|
||||
Box::new(IdentitySection),
|
||||
Box::new(ToolsSection),
|
||||
Box::new(SafetySection),
|
||||
Box::new(SkillsSection),
|
||||
Box::new(WorkspaceSection),
|
||||
Box::new(DateTimeSection),
|
||||
Box::new(RuntimeSection),
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
|
||||
self.sections.push(section);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
let mut output = String::new();
|
||||
for section in &self.sections {
|
||||
let part = section.build(ctx)?;
|
||||
if part.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
output.push_str(part.trim_end());
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdentitySection;
|
||||
pub struct ToolsSection;
|
||||
pub struct SafetySection;
|
||||
pub struct SkillsSection;
|
||||
pub struct WorkspaceSection;
|
||||
pub struct RuntimeSection;
|
||||
pub struct DateTimeSection;
|
||||
|
||||
impl PromptSection for IdentitySection {
|
||||
fn name(&self) -> &str {
|
||||
"identity"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
let mut prompt = String::from("## Project Context\n\n");
|
||||
if let Some(config) = ctx.identity_config {
|
||||
if identity::is_aieos_configured(config) {
|
||||
if let Ok(Some(aieos)) = identity::load_aieos_identity(config, ctx.workspace_dir) {
|
||||
let rendered = identity::aieos_to_system_prompt(&aieos);
|
||||
if !rendered.is_empty() {
|
||||
prompt.push_str(&rendered);
|
||||
return Ok(prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prompt.push_str(
|
||||
"The following workspace files define your identity, behavior, and context.\n\n",
|
||||
);
|
||||
for file in [
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"TOOLS.md",
|
||||
"IDENTITY.md",
|
||||
"USER.md",
|
||||
"HEARTBEAT.md",
|
||||
"BOOTSTRAP.md",
|
||||
"MEMORY.md",
|
||||
] {
|
||||
inject_workspace_file(&mut prompt, ctx.workspace_dir, file);
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for ToolsSection {
|
||||
fn name(&self) -> &str {
|
||||
"tools"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
let mut out = String::from("## Tools\n\n");
|
||||
for tool in ctx.tools {
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"- **{}**: {}\n Parameters: `{}`",
|
||||
tool.name(),
|
||||
tool.description(),
|
||||
tool.parameters_schema()
|
||||
);
|
||||
}
|
||||
if !ctx.dispatcher_instructions.is_empty() {
|
||||
out.push('\n');
|
||||
out.push_str(ctx.dispatcher_instructions);
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for SafetySection {
|
||||
fn name(&self) -> &str {
|
||||
"safety"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> Result<String> {
|
||||
Ok("## Safety\n\n- Do not exfiltrate private data.\n- Do not run destructive commands without asking.\n- Do not bypass oversight or approval mechanisms.\n- Prefer `trash` over `rm`.\n- When in doubt, ask before acting externally.".into())
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for SkillsSection {
|
||||
fn name(&self) -> &str {
|
||||
"skills"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
if ctx.skills.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let mut prompt = String::from("## Available Skills\n\n<available_skills>\n");
|
||||
for skill in ctx.skills {
|
||||
let location = skill.location.clone().unwrap_or_else(|| {
|
||||
ctx.workspace_dir
|
||||
.join("skills")
|
||||
.join(&skill.name)
|
||||
.join("SKILL.md")
|
||||
});
|
||||
let _ = writeln!(
|
||||
prompt,
|
||||
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>",
|
||||
skill.name,
|
||||
skill.description,
|
||||
location.display()
|
||||
);
|
||||
}
|
||||
prompt.push_str("</available_skills>");
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for WorkspaceSection {
|
||||
fn name(&self) -> &str {
|
||||
"workspace"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
Ok(format!(
|
||||
"## Workspace\n\nWorking directory: `{}`",
|
||||
ctx.workspace_dir.display()
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for RuntimeSection {
|
||||
fn name(&self) -> &str {
|
||||
"runtime"
|
||||
}
|
||||
|
||||
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||
let host =
|
||||
hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
|
||||
Ok(format!(
|
||||
"## Runtime\n\nHost: {host} | OS: {} | Model: {}",
|
||||
std::env::consts::OS,
|
||||
ctx.model_name
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptSection for DateTimeSection {
|
||||
fn name(&self) -> &str {
|
||||
"datetime"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> Result<String> {
|
||||
let now = Local::now();
|
||||
Ok(format!(
|
||||
"## Current Date & Time\n\nTimezone: {}",
|
||||
now.format("%Z")
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn inject_workspace_file(prompt: &mut String, workspace_dir: &Path, filename: &str) {
|
||||
let path = workspace_dir.join(filename);
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(content) => {
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
let _ = writeln!(prompt, "### {filename}\n");
|
||||
let truncated = if trimmed.chars().count() > BOOTSTRAP_MAX_CHARS {
|
||||
trimmed
|
||||
.char_indices()
|
||||
.nth(BOOTSTRAP_MAX_CHARS)
|
||||
.map(|(idx, _)| &trimmed[..idx])
|
||||
.unwrap_or(trimmed)
|
||||
} else {
|
||||
trimmed
|
||||
};
|
||||
prompt.push_str(truncated);
|
||||
if truncated.len() < trimmed.len() {
|
||||
let _ = writeln!(
|
||||
prompt,
|
||||
"\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n"
|
||||
);
|
||||
} else {
|
||||
prompt.push_str("\n\n");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tools::traits::Tool;
|
||||
use async_trait::async_trait;
|
||||
|
||||
struct TestTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TestTool {
|
||||
fn name(&self) -> &str {
|
||||
"test_tool"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"tool desc"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({"type": "object"})
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: true,
|
||||
output: "ok".into(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_builder_assembles_sections() {
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(TestTool)];
|
||||
let ctx = PromptContext {
|
||||
workspace_dir: Path::new("/tmp"),
|
||||
model_name: "test-model",
|
||||
tools: &tools,
|
||||
skills: &[],
|
||||
identity_config: None,
|
||||
dispatcher_instructions: "instr",
|
||||
};
|
||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx).unwrap();
|
||||
assert!(prompt.contains("## Tools"));
|
||||
assert!(prompt.contains("test_tool"));
|
||||
assert!(prompt.contains("instr"));
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue