feat: add agent structure and improve tooling for provider
This commit is contained in:
parent
e2c966d31e
commit
b341fdb368
21 changed files with 2567 additions and 443 deletions
701
src/agent/agent.rs
Normal file
701
src/agent/agent.rs
Normal file
|
|
@ -0,0 +1,701 @@
|
||||||
|
use crate::agent::dispatcher::{
|
||||||
|
NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher,
|
||||||
|
};
|
||||||
|
use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader};
|
||||||
|
use crate::agent::prompt::{PromptContext, SystemPromptBuilder};
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::memory::{self, Memory, MemoryCategory};
|
||||||
|
use crate::observability::{self, Observer, ObserverEvent};
|
||||||
|
use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Provider};
|
||||||
|
use crate::runtime;
|
||||||
|
use crate::security::SecurityPolicy;
|
||||||
|
use crate::tools::{self, Tool, ToolSpec};
|
||||||
|
use crate::util::truncate_with_ellipsis;
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::io::Write as IoWrite;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
pub struct Agent {
|
||||||
|
provider: Box<dyn Provider>,
|
||||||
|
tools: Vec<Box<dyn Tool>>,
|
||||||
|
tool_specs: Vec<ToolSpec>,
|
||||||
|
memory: Arc<dyn Memory>,
|
||||||
|
observer: Arc<dyn Observer>,
|
||||||
|
prompt_builder: SystemPromptBuilder,
|
||||||
|
tool_dispatcher: Box<dyn ToolDispatcher>,
|
||||||
|
memory_loader: Box<dyn MemoryLoader>,
|
||||||
|
config: crate::config::AgentConfig,
|
||||||
|
model_name: String,
|
||||||
|
temperature: f64,
|
||||||
|
workspace_dir: std::path::PathBuf,
|
||||||
|
identity_config: crate::config::IdentityConfig,
|
||||||
|
skills: Vec<crate::skills::Skill>,
|
||||||
|
auto_save: bool,
|
||||||
|
history: Vec<ConversationMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AgentBuilder {
|
||||||
|
provider: Option<Box<dyn Provider>>,
|
||||||
|
tools: Option<Vec<Box<dyn Tool>>>,
|
||||||
|
memory: Option<Arc<dyn Memory>>,
|
||||||
|
observer: Option<Arc<dyn Observer>>,
|
||||||
|
prompt_builder: Option<SystemPromptBuilder>,
|
||||||
|
tool_dispatcher: Option<Box<dyn ToolDispatcher>>,
|
||||||
|
memory_loader: Option<Box<dyn MemoryLoader>>,
|
||||||
|
config: Option<crate::config::AgentConfig>,
|
||||||
|
model_name: Option<String>,
|
||||||
|
temperature: Option<f64>,
|
||||||
|
workspace_dir: Option<std::path::PathBuf>,
|
||||||
|
identity_config: Option<crate::config::IdentityConfig>,
|
||||||
|
skills: Option<Vec<crate::skills::Skill>>,
|
||||||
|
auto_save: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentBuilder {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
provider: None,
|
||||||
|
tools: None,
|
||||||
|
memory: None,
|
||||||
|
observer: None,
|
||||||
|
prompt_builder: None,
|
||||||
|
tool_dispatcher: None,
|
||||||
|
memory_loader: None,
|
||||||
|
config: None,
|
||||||
|
model_name: None,
|
||||||
|
temperature: None,
|
||||||
|
workspace_dir: None,
|
||||||
|
identity_config: None,
|
||||||
|
skills: None,
|
||||||
|
auto_save: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn provider(mut self, provider: Box<dyn Provider>) -> Self {
|
||||||
|
self.provider = Some(provider);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tools(mut self, tools: Vec<Box<dyn Tool>>) -> Self {
|
||||||
|
self.tools = Some(tools);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memory(mut self, memory: Arc<dyn Memory>) -> Self {
|
||||||
|
self.memory = Some(memory);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn observer(mut self, observer: Arc<dyn Observer>) -> Self {
|
||||||
|
self.observer = Some(observer);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prompt_builder(mut self, prompt_builder: SystemPromptBuilder) -> Self {
|
||||||
|
self.prompt_builder = Some(prompt_builder);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_dispatcher(mut self, tool_dispatcher: Box<dyn ToolDispatcher>) -> Self {
|
||||||
|
self.tool_dispatcher = Some(tool_dispatcher);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memory_loader(mut self, memory_loader: Box<dyn MemoryLoader>) -> Self {
|
||||||
|
self.memory_loader = Some(memory_loader);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(mut self, config: crate::config::AgentConfig) -> Self {
|
||||||
|
self.config = Some(config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn model_name(mut self, model_name: String) -> Self {
|
||||||
|
self.model_name = Some(model_name);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn temperature(mut self, temperature: f64) -> Self {
|
||||||
|
self.temperature = Some(temperature);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn workspace_dir(mut self, workspace_dir: std::path::PathBuf) -> Self {
|
||||||
|
self.workspace_dir = Some(workspace_dir);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn identity_config(mut self, identity_config: crate::config::IdentityConfig) -> Self {
|
||||||
|
self.identity_config = Some(identity_config);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn skills(mut self, skills: Vec<crate::skills::Skill>) -> Self {
|
||||||
|
self.skills = Some(skills);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn auto_save(mut self, auto_save: bool) -> Self {
|
||||||
|
self.auto_save = Some(auto_save);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> Result<Agent> {
|
||||||
|
let tools = self
|
||||||
|
.tools
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("tools are required"))?;
|
||||||
|
let tool_specs = tools.iter().map(|tool| tool.spec()).collect();
|
||||||
|
|
||||||
|
Ok(Agent {
|
||||||
|
provider: self
|
||||||
|
.provider
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("provider is required"))?,
|
||||||
|
tools,
|
||||||
|
tool_specs,
|
||||||
|
memory: self
|
||||||
|
.memory
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("memory is required"))?,
|
||||||
|
observer: self
|
||||||
|
.observer
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("observer is required"))?,
|
||||||
|
prompt_builder: self
|
||||||
|
.prompt_builder
|
||||||
|
.unwrap_or_else(SystemPromptBuilder::with_defaults),
|
||||||
|
tool_dispatcher: self
|
||||||
|
.tool_dispatcher
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("tool_dispatcher is required"))?,
|
||||||
|
memory_loader: self
|
||||||
|
.memory_loader
|
||||||
|
.unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())),
|
||||||
|
config: self.config.unwrap_or_default(),
|
||||||
|
model_name: self
|
||||||
|
.model_name
|
||||||
|
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()),
|
||||||
|
temperature: self.temperature.unwrap_or(0.7),
|
||||||
|
workspace_dir: self
|
||||||
|
.workspace_dir
|
||||||
|
.unwrap_or_else(|| std::path::PathBuf::from(".")),
|
||||||
|
identity_config: self.identity_config.unwrap_or_default(),
|
||||||
|
skills: self.skills.unwrap_or_default(),
|
||||||
|
auto_save: self.auto_save.unwrap_or(false),
|
||||||
|
history: Vec::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Agent {
|
||||||
|
pub fn builder() -> AgentBuilder {
|
||||||
|
AgentBuilder::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn history(&self) -> &[ConversationMessage] {
|
||||||
|
&self.history
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_history(&mut self) {
|
||||||
|
self.history.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_config(config: &Config) -> Result<Self> {
|
||||||
|
let observer: Arc<dyn Observer> =
|
||||||
|
Arc::from(observability::create_observer(&config.observability));
|
||||||
|
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||||
|
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||||
|
let security = Arc::new(SecurityPolicy::from_config(
|
||||||
|
&config.autonomy,
|
||||||
|
&config.workspace_dir,
|
||||||
|
));
|
||||||
|
|
||||||
|
let memory: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
|
&config.memory,
|
||||||
|
&config.workspace_dir,
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
)?);
|
||||||
|
|
||||||
|
let composio_key = if config.composio.enabled {
|
||||||
|
config.composio.api_key.as_deref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools = tools::all_tools_with_runtime(
|
||||||
|
&security,
|
||||||
|
runtime,
|
||||||
|
memory.clone(),
|
||||||
|
composio_key,
|
||||||
|
&config.browser,
|
||||||
|
&config.http_request,
|
||||||
|
&config.workspace_dir,
|
||||||
|
&config.agents,
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider_name = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||||
|
|
||||||
|
let model_name = config
|
||||||
|
.default_model
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||||
|
provider_name,
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
|
&config.model_routes,
|
||||||
|
&model_name,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let dispatcher_choice = config.agent.tool_dispatcher.as_str();
|
||||||
|
let tool_dispatcher: Box<dyn ToolDispatcher> = match dispatcher_choice {
|
||||||
|
"native" => Box::new(NativeToolDispatcher),
|
||||||
|
"xml" => Box::new(XmlToolDispatcher),
|
||||||
|
_ if provider.supports_native_tools() => Box::new(NativeToolDispatcher),
|
||||||
|
_ => Box::new(XmlToolDispatcher),
|
||||||
|
};
|
||||||
|
|
||||||
|
Agent::builder()
|
||||||
|
.provider(provider)
|
||||||
|
.tools(tools)
|
||||||
|
.memory(memory)
|
||||||
|
.observer(observer)
|
||||||
|
.tool_dispatcher(tool_dispatcher)
|
||||||
|
.memory_loader(Box::new(DefaultMemoryLoader::default()))
|
||||||
|
.prompt_builder(SystemPromptBuilder::with_defaults())
|
||||||
|
.config(config.agent.clone())
|
||||||
|
.model_name(model_name)
|
||||||
|
.temperature(config.default_temperature)
|
||||||
|
.workspace_dir(config.workspace_dir.clone())
|
||||||
|
.identity_config(config.identity.clone())
|
||||||
|
.skills(crate::skills::load_skills(&config.workspace_dir))
|
||||||
|
.auto_save(config.memory.auto_save)
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn trim_history(&mut self) {
|
||||||
|
let max = self.config.max_history_messages;
|
||||||
|
if self.history.len() <= max {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut system_messages = Vec::new();
|
||||||
|
let mut other_messages = Vec::new();
|
||||||
|
|
||||||
|
for msg in self.history.drain(..) {
|
||||||
|
match &msg {
|
||||||
|
ConversationMessage::Chat(chat) if chat.role == "system" => {
|
||||||
|
system_messages.push(msg)
|
||||||
|
}
|
||||||
|
_ => other_messages.push(msg),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if other_messages.len() > max {
|
||||||
|
let drop_count = other_messages.len() - max;
|
||||||
|
other_messages.drain(0..drop_count);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.history = system_messages;
|
||||||
|
self.history.extend(other_messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_system_prompt(&self) -> Result<String> {
|
||||||
|
let instructions = self.tool_dispatcher.prompt_instructions(&self.tools);
|
||||||
|
let ctx = PromptContext {
|
||||||
|
workspace_dir: &self.workspace_dir,
|
||||||
|
model_name: &self.model_name,
|
||||||
|
tools: &self.tools,
|
||||||
|
skills: &self.skills,
|
||||||
|
identity_config: Some(&self.identity_config),
|
||||||
|
dispatcher_instructions: &instructions,
|
||||||
|
};
|
||||||
|
self.prompt_builder.build(&ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_tool_call(&self, call: &ParsedToolCall) -> ToolExecutionResult {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let result = if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) {
|
||||||
|
match tool.execute(call.arguments.clone()).await {
|
||||||
|
Ok(r) => {
|
||||||
|
self.observer.record_event(&ObserverEvent::ToolCall {
|
||||||
|
tool: call.name.clone(),
|
||||||
|
duration: start.elapsed(),
|
||||||
|
success: r.success,
|
||||||
|
});
|
||||||
|
if r.success {
|
||||||
|
r.output
|
||||||
|
} else {
|
||||||
|
format!("Error: {}", r.error.unwrap_or(r.output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
self.observer.record_event(&ObserverEvent::ToolCall {
|
||||||
|
tool: call.name.clone(),
|
||||||
|
duration: start.elapsed(),
|
||||||
|
success: false,
|
||||||
|
});
|
||||||
|
format!("Error executing {}: {e}", call.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
format!("Unknown tool: {}", call.name)
|
||||||
|
};
|
||||||
|
|
||||||
|
ToolExecutionResult {
|
||||||
|
name: call.name.clone(),
|
||||||
|
output: result,
|
||||||
|
success: true,
|
||||||
|
tool_call_id: call.tool_call_id.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_tools(&self, calls: &[ParsedToolCall]) -> Vec<ToolExecutionResult> {
|
||||||
|
if !self.config.parallel_tools {
|
||||||
|
let mut results = Vec::with_capacity(calls.len());
|
||||||
|
for call in calls {
|
||||||
|
results.push(self.execute_tool_call(call).await);
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::with_capacity(calls.len());
|
||||||
|
for call in calls {
|
||||||
|
results.push(self.execute_tool_call(call).await);
|
||||||
|
}
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn turn(&mut self, user_message: &str) -> Result<String> {
|
||||||
|
if self.history.is_empty() {
|
||||||
|
let system_prompt = self.build_system_prompt()?;
|
||||||
|
self.history
|
||||||
|
.push(ConversationMessage::Chat(ChatMessage::system(
|
||||||
|
system_prompt,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.auto_save {
|
||||||
|
let _ = self
|
||||||
|
.memory
|
||||||
|
.store("user_msg", user_message, MemoryCategory::Conversation)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let context = self
|
||||||
|
.memory_loader
|
||||||
|
.load_context(self.memory.as_ref(), user_message)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let enriched = if context.is_empty() {
|
||||||
|
user_message.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{context}{user_message}")
|
||||||
|
};
|
||||||
|
|
||||||
|
self.history
|
||||||
|
.push(ConversationMessage::Chat(ChatMessage::user(enriched)));
|
||||||
|
|
||||||
|
for _ in 0..self.config.max_tool_iterations {
|
||||||
|
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||||
|
let response = match self
|
||||||
|
.provider
|
||||||
|
.chat(
|
||||||
|
ChatRequest {
|
||||||
|
messages: &messages,
|
||||||
|
tools: if self.tool_dispatcher.should_send_tool_specs() {
|
||||||
|
Some(&self.tool_specs)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&self.model_name,
|
||||||
|
self.temperature,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (text, calls) = self.tool_dispatcher.parse_response(&response);
|
||||||
|
if calls.is_empty() {
|
||||||
|
let final_text = if text.is_empty() {
|
||||||
|
response.text.unwrap_or_default()
|
||||||
|
} else {
|
||||||
|
text
|
||||||
|
};
|
||||||
|
|
||||||
|
self.history
|
||||||
|
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||||
|
final_text.clone(),
|
||||||
|
)));
|
||||||
|
self.trim_history();
|
||||||
|
|
||||||
|
if self.auto_save {
|
||||||
|
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||||
|
let _ = self
|
||||||
|
.memory
|
||||||
|
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(final_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !text.is_empty() {
|
||||||
|
self.history
|
||||||
|
.push(ConversationMessage::Chat(ChatMessage::assistant(
|
||||||
|
text.clone(),
|
||||||
|
)));
|
||||||
|
print!("{text}");
|
||||||
|
let _ = std::io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
self.history.push(ConversationMessage::AssistantToolCalls {
|
||||||
|
text: response.text.clone(),
|
||||||
|
tool_calls: response.tool_calls.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let results = self.execute_tools(&calls).await;
|
||||||
|
let formatted = self.tool_dispatcher.format_results(&results);
|
||||||
|
self.history.push(formatted);
|
||||||
|
self.trim_history();
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!(
|
||||||
|
"Agent exceeded maximum tool iterations ({})",
|
||||||
|
self.config.max_tool_iterations
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_single(&mut self, message: &str) -> Result<String> {
|
||||||
|
self.turn(message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_interactive(&mut self) -> Result<()> {
|
||||||
|
println!("🦀 ZeroClaw Interactive Mode");
|
||||||
|
println!("Type /quit to exit.\n");
|
||||||
|
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
|
||||||
|
let cli = crate::channels::CliChannel::new();
|
||||||
|
|
||||||
|
let listen_handle = tokio::spawn(async move {
|
||||||
|
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
while let Some(msg) = rx.recv().await {
|
||||||
|
let response = match self.turn(&msg.content).await {
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("\nError: {e}\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!("\n{response}\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
listen_handle.abort();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(
|
||||||
|
config: Config,
|
||||||
|
message: Option<String>,
|
||||||
|
provider_override: Option<String>,
|
||||||
|
model_override: Option<String>,
|
||||||
|
temperature: f64,
|
||||||
|
) -> Result<()> {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let mut effective_config = config;
|
||||||
|
if let Some(p) = provider_override {
|
||||||
|
effective_config.default_provider = Some(p);
|
||||||
|
}
|
||||||
|
if let Some(m) = model_override {
|
||||||
|
effective_config.default_model = Some(m);
|
||||||
|
}
|
||||||
|
effective_config.default_temperature = temperature;
|
||||||
|
|
||||||
|
let mut agent = Agent::from_config(&effective_config)?;
|
||||||
|
|
||||||
|
let provider_name = effective_config
|
||||||
|
.default_provider
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("openrouter")
|
||||||
|
.to_string();
|
||||||
|
let model_name = effective_config
|
||||||
|
.default_model
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("anthropic/claude-sonnet-4-20250514")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
agent.observer.record_event(&ObserverEvent::AgentStart {
|
||||||
|
provider: provider_name,
|
||||||
|
model: model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(msg) = message {
|
||||||
|
let response = agent.run_single(&msg).await?;
|
||||||
|
println!("{response}");
|
||||||
|
} else {
|
||||||
|
agent.run_interactive().await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
||||||
|
duration: start.elapsed(),
|
||||||
|
tokens_used: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
struct MockProvider {
|
||||||
|
responses: Mutex<Vec<crate::providers::ChatResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> Result<String> {
|
||||||
|
Ok("ok".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
_request: ChatRequest<'_>,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> Result<crate::providers::ChatResponse> {
|
||||||
|
let mut guard = self.responses.lock().unwrap();
|
||||||
|
if guard.is_empty() {
|
||||||
|
return Ok(crate::providers::ChatResponse {
|
||||||
|
text: Some("done".into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(guard.remove(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MockTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MockTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"echo"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"echo"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({"type": "object"})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
|
||||||
|
Ok(crate::tools::ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "tool-out".into(),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn turn_without_tools_returns_text() {
|
||||||
|
let provider = Box::new(MockProvider {
|
||||||
|
responses: Mutex::new(vec![crate::providers::ChatResponse {
|
||||||
|
text: Some("hello".into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
}]),
|
||||||
|
});
|
||||||
|
|
||||||
|
let memory_cfg = crate::config::MemoryConfig {
|
||||||
|
backend: "none".into(),
|
||||||
|
..crate::config::MemoryConfig::default()
|
||||||
|
};
|
||||||
|
let mem: Arc<dyn Memory> = Arc::from(
|
||||||
|
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||||
|
let mut agent = Agent::builder()
|
||||||
|
.provider(provider)
|
||||||
|
.tools(vec![Box::new(MockTool)])
|
||||||
|
.memory(mem)
|
||||||
|
.observer(observer)
|
||||||
|
.tool_dispatcher(Box::new(XmlToolDispatcher))
|
||||||
|
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = agent.turn("hi").await.unwrap();
|
||||||
|
assert_eq!(response, "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn turn_with_native_dispatcher_handles_tool_results_variant() {
|
||||||
|
let provider = Box::new(MockProvider {
|
||||||
|
responses: Mutex::new(vec![
|
||||||
|
crate::providers::ChatResponse {
|
||||||
|
text: Some("".into()),
|
||||||
|
tool_calls: vec![crate::providers::ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
crate::providers::ChatResponse {
|
||||||
|
text: Some("done".into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
});
|
||||||
|
|
||||||
|
let memory_cfg = crate::config::MemoryConfig {
|
||||||
|
backend: "none".into(),
|
||||||
|
..crate::config::MemoryConfig::default()
|
||||||
|
};
|
||||||
|
let mem: Arc<dyn Memory> = Arc::from(
|
||||||
|
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||||
|
let mut agent = Agent::builder()
|
||||||
|
.provider(provider)
|
||||||
|
.tools(vec![Box::new(MockTool)])
|
||||||
|
.memory(mem)
|
||||||
|
.observer(observer)
|
||||||
|
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||||
|
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = agent.turn("hi").await.unwrap();
|
||||||
|
assert_eq!(response, "done");
|
||||||
|
assert!(matches!(
|
||||||
|
agent
|
||||||
|
.history()
|
||||||
|
.iter()
|
||||||
|
.find(|msg| matches!(msg, ConversationMessage::ToolResults(_))),
|
||||||
|
Some(_)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
312
src/agent/dispatcher.rs
Normal file
312
src/agent/dispatcher.rs
Normal file
|
|
@ -0,0 +1,312 @@
|
||||||
|
use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
|
||||||
|
use crate::tools::{Tool, ToolSpec};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ParsedToolCall {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolExecutionResult {
|
||||||
|
pub name: String,
|
||||||
|
pub output: String,
|
||||||
|
pub success: bool,
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ToolDispatcher: Send + Sync {
|
||||||
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
|
||||||
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
|
||||||
|
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
|
||||||
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
|
||||||
|
fn should_send_tool_specs(&self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct XmlToolDispatcher;
|
||||||
|
|
||||||
|
impl XmlToolDispatcher {
|
||||||
|
fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
let mut calls = Vec::new();
|
||||||
|
let mut remaining = response;
|
||||||
|
|
||||||
|
while let Some(start) = remaining.find("<tool_call>") {
|
||||||
|
let before = &remaining[..start];
|
||||||
|
if !before.trim().is_empty() {
|
||||||
|
text_parts.push(before.trim().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(end) = remaining[start..].find("</tool_call>") {
|
||||||
|
let inner = &remaining[start + 11..start + end];
|
||||||
|
match serde_json::from_str::<Value>(inner.trim()) {
|
||||||
|
Ok(parsed) => {
|
||||||
|
let name = parsed
|
||||||
|
.get("name")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
if name.is_empty() {
|
||||||
|
remaining = &remaining[start + end + 12..];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let arguments = parsed
|
||||||
|
.get("arguments")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| Value::Object(serde_json::Map::new()));
|
||||||
|
calls.push(ParsedToolCall {
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Malformed <tool_call> JSON: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
remaining = &remaining[start + end + 12..];
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !remaining.trim().is_empty() {
|
||||||
|
text_parts.push(remaining.trim().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
(text_parts.join("\n"), calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
|
||||||
|
tools.iter().map(|tool| tool.spec()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolDispatcher for XmlToolDispatcher {
|
||||||
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
||||||
|
let text = response.text_or_empty();
|
||||||
|
Self::parse_xml_tool_calls(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
||||||
|
let mut content = String::new();
|
||||||
|
for result in results {
|
||||||
|
let status = if result.success { "ok" } else { "error" };
|
||||||
|
let _ = writeln!(
|
||||||
|
content,
|
||||||
|
"<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
|
||||||
|
result.name, status, result.output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
|
||||||
|
let mut instructions = String::new();
|
||||||
|
instructions.push_str("## Tool Use Protocol\n\n");
|
||||||
|
instructions
|
||||||
|
.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||||
|
instructions.push_str(
|
||||||
|
"```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
|
||||||
|
);
|
||||||
|
instructions.push_str("### Available Tools\n\n");
|
||||||
|
|
||||||
|
for tool in tools {
|
||||||
|
let _ = writeln!(
|
||||||
|
instructions,
|
||||||
|
"- **{}**: {}\n Parameters: `{}`",
|
||||||
|
tool.name(),
|
||||||
|
tool.description(),
|
||||||
|
tool.parameters_schema()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.flat_map(|msg| match msg {
|
||||||
|
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
||||||
|
ConversationMessage::AssistantToolCalls { text, .. } => {
|
||||||
|
vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
|
||||||
|
}
|
||||||
|
ConversationMessage::ToolResults(results) => {
|
||||||
|
let mut content = String::new();
|
||||||
|
for result in results {
|
||||||
|
let _ = writeln!(
|
||||||
|
content,
|
||||||
|
"<tool_result id=\"{}\">\n{}\n</tool_result>",
|
||||||
|
result.tool_call_id, result.content
|
||||||
|
);
|
||||||
|
}
|
||||||
|
vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_send_tool_specs(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct NativeToolDispatcher;
|
||||||
|
|
||||||
|
impl ToolDispatcher for NativeToolDispatcher {
|
||||||
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
||||||
|
let text = response.text.clone().unwrap_or_default();
|
||||||
|
let calls = response
|
||||||
|
.tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| ParsedToolCall {
|
||||||
|
name: tc.name.clone(),
|
||||||
|
arguments: serde_json::from_str(&tc.arguments)
|
||||||
|
.unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
|
||||||
|
tool_call_id: Some(tc.id.clone()),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
(text, calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
||||||
|
let messages = results
|
||||||
|
.iter()
|
||||||
|
.map(|result| ToolResultMessage {
|
||||||
|
tool_call_id: result
|
||||||
|
.tool_call_id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "unknown".to_string()),
|
||||||
|
content: result.output.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
ConversationMessage::ToolResults(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
||||||
|
history
|
||||||
|
.iter()
|
||||||
|
.flat_map(|msg| match msg {
|
||||||
|
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
||||||
|
ConversationMessage::AssistantToolCalls { text, tool_calls } => {
|
||||||
|
let payload = serde_json::json!({
|
||||||
|
"content": text,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
});
|
||||||
|
vec![ChatMessage::assistant(payload.to_string())]
|
||||||
|
}
|
||||||
|
ConversationMessage::ToolResults(results) => results
|
||||||
|
.iter()
|
||||||
|
.map(|result| {
|
||||||
|
ChatMessage::tool(
|
||||||
|
serde_json::json!({
|
||||||
|
"tool_call_id": result.tool_call_id,
|
||||||
|
"content": result.content,
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_send_tool_specs(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn xml_dispatcher_parses_tool_calls() {
|
||||||
|
let response = ChatResponse {
|
||||||
|
text: Some(
|
||||||
|
"Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>"
|
||||||
|
.into(),
|
||||||
|
),
|
||||||
|
tool_calls: vec![],
|
||||||
|
};
|
||||||
|
let dispatcher = XmlToolDispatcher;
|
||||||
|
let (_, calls) = dispatcher.parse_response(&response);
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn native_dispatcher_roundtrip() {
|
||||||
|
let response = ChatResponse {
|
||||||
|
text: Some("ok".into()),
|
||||||
|
tool_calls: vec![crate::providers::ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "file_read".into(),
|
||||||
|
arguments: "{\"path\":\"a.txt\"}".into(),
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
let dispatcher = NativeToolDispatcher;
|
||||||
|
let (_, calls) = dispatcher.parse_response(&response);
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));
|
||||||
|
|
||||||
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||||
|
name: "file_read".into(),
|
||||||
|
output: "hello".into(),
|
||||||
|
success: true,
|
||||||
|
tool_call_id: Some("tc1".into()),
|
||||||
|
}]);
|
||||||
|
match msg {
|
||||||
|
ConversationMessage::ToolResults(results) => {
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].tool_call_id, "tc1");
|
||||||
|
}
|
||||||
|
_ => panic!("expected tool results"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn xml_format_results_contains_tool_result_tags() {
|
||||||
|
let dispatcher = XmlToolDispatcher;
|
||||||
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||||
|
name: "shell".into(),
|
||||||
|
output: "ok".into(),
|
||||||
|
success: true,
|
||||||
|
tool_call_id: None,
|
||||||
|
}]);
|
||||||
|
let rendered = match msg {
|
||||||
|
ConversationMessage::Chat(chat) => chat.content,
|
||||||
|
_ => String::new(),
|
||||||
|
};
|
||||||
|
assert!(rendered.contains("<tool_result"));
|
||||||
|
assert!(rendered.contains("shell"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn native_format_results_keeps_tool_call_id() {
|
||||||
|
let dispatcher = NativeToolDispatcher;
|
||||||
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
||||||
|
name: "shell".into(),
|
||||||
|
output: "ok".into(),
|
||||||
|
success: true,
|
||||||
|
tool_call_id: Some("tc-1".into()),
|
||||||
|
}]);
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
ConversationMessage::ToolResults(results) => {
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].tool_call_id, "tc-1");
|
||||||
|
}
|
||||||
|
_ => panic!("expected ToolResults variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -8,11 +8,10 @@ use crate::tools::{self, Tool};
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::io::Write as IoWrite;
|
use std::io::Write as _;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||||
const MAX_TOOL_ITERATIONS: usize = 10;
|
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||||
|
|
||||||
|
|
@ -113,7 +112,6 @@ async fn auto_compact_history(
|
||||||
let summary_raw = provider
|
let summary_raw = provider
|
||||||
.chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2)
|
.chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2)
|
||||||
.await
|
.await
|
||||||
.map(|resp| resp.text_or_empty().to_string())
|
|
||||||
.unwrap_or_else(|_| {
|
.unwrap_or_else(|_| {
|
||||||
// Fallback to deterministic local truncation when summarization fails.
|
// Fallback to deterministic local truncation when summarization fails.
|
||||||
truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS)
|
truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS)
|
||||||
|
|
@ -482,21 +480,11 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let response_text = response.text.unwrap_or_default();
|
let response_text = response;
|
||||||
let mut assistant_history_content = response_text.clone();
|
let mut assistant_history_content = response_text.clone();
|
||||||
let mut parsed_text = response_text.clone();
|
let (parsed_text, tool_calls) = parse_tool_calls(&response_text);
|
||||||
let mut tool_calls = parse_structured_tool_calls(&response.tool_calls);
|
let mut parsed_text = parsed_text;
|
||||||
|
let mut tool_calls = tool_calls;
|
||||||
if !response.tool_calls.is_empty() {
|
|
||||||
assistant_history_content =
|
|
||||||
build_assistant_history_with_tool_calls(&response_text, &response.tool_calls);
|
|
||||||
}
|
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
|
||||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
|
||||||
parsed_text = fallback_text;
|
|
||||||
tool_calls = fallback_calls;
|
|
||||||
}
|
|
||||||
|
|
||||||
if tool_calls.is_empty() {
|
if tool_calls.is_empty() {
|
||||||
// No tool calls — this is the final response
|
// No tool calls — this is the final response
|
||||||
|
|
|
||||||
118
src/agent/memory_loader.rs
Normal file
118
src/agent/memory_loader.rs
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
use crate::memory::Memory;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait MemoryLoader: Send + Sync {
|
||||||
|
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
|
||||||
|
-> anyhow::Result<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DefaultMemoryLoader {
|
||||||
|
limit: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DefaultMemoryLoader {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self { limit: 5 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DefaultMemoryLoader {
|
||||||
|
pub fn new(limit: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
limit: limit.max(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl MemoryLoader for DefaultMemoryLoader {
|
||||||
|
async fn load_context(
|
||||||
|
&self,
|
||||||
|
memory: &dyn Memory,
|
||||||
|
user_message: &str,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let entries = memory.recall(user_message, self.limit).await?;
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Ok(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut context = String::from("[Memory context]\n");
|
||||||
|
for entry in entries {
|
||||||
|
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||||
|
}
|
||||||
|
context.push('\n');
|
||||||
|
Ok(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
|
||||||
|
|
||||||
|
struct MockMemory;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Memory for MockMemory {
|
||||||
|
async fn store(
|
||||||
|
&self,
|
||||||
|
_key: &str,
|
||||||
|
_content: &str,
|
||||||
|
_category: MemoryCategory,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
if limit == 0 {
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
Ok(vec![MemoryEntry {
|
||||||
|
id: "1".into(),
|
||||||
|
key: "k".into(),
|
||||||
|
content: "v".into(),
|
||||||
|
category: MemoryCategory::Conversation,
|
||||||
|
timestamp: "now".into(),
|
||||||
|
session_id: None,
|
||||||
|
score: None,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list(
|
||||||
|
&self,
|
||||||
|
_category: Option<&MemoryCategory>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
Ok(vec![])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
|
Ok(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn default_loader_formats_context() {
|
||||||
|
let loader = DefaultMemoryLoader::default();
|
||||||
|
let context = loader.load_context(&MockMemory, "hello").await.unwrap();
|
||||||
|
assert!(context.contains("[Memory context]"));
|
||||||
|
assert!(context.contains("- k: v"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,24 @@
|
||||||
|
pub mod agent;
|
||||||
|
pub mod dispatcher;
|
||||||
pub mod loop_;
|
pub mod loop_;
|
||||||
|
pub mod memory_loader;
|
||||||
|
pub mod prompt;
|
||||||
|
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
pub use agent::{Agent, AgentBuilder};
|
||||||
pub use loop_::{process_message, run};
|
pub use loop_::{process_message, run};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn assert_reexport_exists<F>(_value: F) {}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn run_function_is_reexported() {
|
||||||
|
assert_reexport_exists(run);
|
||||||
|
assert_reexport_exists(process_message);
|
||||||
|
assert_reexport_exists(loop_::run);
|
||||||
|
assert_reexport_exists(loop_::process_message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
304
src/agent/prompt.rs
Normal file
304
src/agent/prompt.rs
Normal file
|
|
@ -0,0 +1,304 @@
|
||||||
|
use crate::config::IdentityConfig;
|
||||||
|
use crate::identity;
|
||||||
|
use crate::skills::Skill;
|
||||||
|
use crate::tools::Tool;
|
||||||
|
use anyhow::Result;
|
||||||
|
use chrono::Local;
|
||||||
|
use std::fmt::Write;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||||||
|
|
||||||
|
pub struct PromptContext<'a> {
|
||||||
|
pub workspace_dir: &'a Path,
|
||||||
|
pub model_name: &'a str,
|
||||||
|
pub tools: &'a [Box<dyn Tool>],
|
||||||
|
pub skills: &'a [Skill],
|
||||||
|
pub identity_config: Option<&'a IdentityConfig>,
|
||||||
|
pub dispatcher_instructions: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait PromptSection: Send + Sync {
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct SystemPromptBuilder {
|
||||||
|
sections: Vec<Box<dyn PromptSection>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SystemPromptBuilder {
|
||||||
|
pub fn with_defaults() -> Self {
|
||||||
|
Self {
|
||||||
|
sections: vec![
|
||||||
|
Box::new(IdentitySection),
|
||||||
|
Box::new(ToolsSection),
|
||||||
|
Box::new(SafetySection),
|
||||||
|
Box::new(SkillsSection),
|
||||||
|
Box::new(WorkspaceSection),
|
||||||
|
Box::new(DateTimeSection),
|
||||||
|
Box::new(RuntimeSection),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_section(mut self, section: Box<dyn PromptSection>) -> Self {
|
||||||
|
self.sections.push(section);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
let mut output = String::new();
|
||||||
|
for section in &self.sections {
|
||||||
|
let part = section.build(ctx)?;
|
||||||
|
if part.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
output.push_str(part.trim_end());
|
||||||
|
output.push_str("\n\n");
|
||||||
|
}
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IdentitySection;
|
||||||
|
pub struct ToolsSection;
|
||||||
|
pub struct SafetySection;
|
||||||
|
pub struct SkillsSection;
|
||||||
|
pub struct WorkspaceSection;
|
||||||
|
pub struct RuntimeSection;
|
||||||
|
pub struct DateTimeSection;
|
||||||
|
|
||||||
|
impl PromptSection for IdentitySection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"identity"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
let mut prompt = String::from("## Project Context\n\n");
|
||||||
|
if let Some(config) = ctx.identity_config {
|
||||||
|
if identity::is_aieos_configured(config) {
|
||||||
|
if let Ok(Some(aieos)) = identity::load_aieos_identity(config, ctx.workspace_dir) {
|
||||||
|
let rendered = identity::aieos_to_system_prompt(&aieos);
|
||||||
|
if !rendered.is_empty() {
|
||||||
|
prompt.push_str(&rendered);
|
||||||
|
return Ok(prompt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.push_str(
|
||||||
|
"The following workspace files define your identity, behavior, and context.\n\n",
|
||||||
|
);
|
||||||
|
for file in [
|
||||||
|
"AGENTS.md",
|
||||||
|
"SOUL.md",
|
||||||
|
"TOOLS.md",
|
||||||
|
"IDENTITY.md",
|
||||||
|
"USER.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
"BOOTSTRAP.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
] {
|
||||||
|
inject_workspace_file(&mut prompt, ctx.workspace_dir, file);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for ToolsSection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"tools"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
let mut out = String::from("## Tools\n\n");
|
||||||
|
for tool in ctx.tools {
|
||||||
|
let _ = writeln!(
|
||||||
|
out,
|
||||||
|
"- **{}**: {}\n Parameters: `{}`",
|
||||||
|
tool.name(),
|
||||||
|
tool.description(),
|
||||||
|
tool.parameters_schema()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if !ctx.dispatcher_instructions.is_empty() {
|
||||||
|
out.push('\n');
|
||||||
|
out.push_str(ctx.dispatcher_instructions);
|
||||||
|
}
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for SafetySection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"safety"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, _ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
Ok("## Safety\n\n- Do not exfiltrate private data.\n- Do not run destructive commands without asking.\n- Do not bypass oversight or approval mechanisms.\n- Prefer `trash` over `rm`.\n- When in doubt, ask before acting externally.".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for SkillsSection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"skills"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
if ctx.skills.is_empty() {
|
||||||
|
return Ok(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut prompt = String::from("## Available Skills\n\n<available_skills>\n");
|
||||||
|
for skill in ctx.skills {
|
||||||
|
let location = skill.location.clone().unwrap_or_else(|| {
|
||||||
|
ctx.workspace_dir
|
||||||
|
.join("skills")
|
||||||
|
.join(&skill.name)
|
||||||
|
.join("SKILL.md")
|
||||||
|
});
|
||||||
|
let _ = writeln!(
|
||||||
|
prompt,
|
||||||
|
" <skill>\n <name>{}</name>\n <description>{}</description>\n <location>{}</location>\n </skill>",
|
||||||
|
skill.name,
|
||||||
|
skill.description,
|
||||||
|
location.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
prompt.push_str("</available_skills>");
|
||||||
|
Ok(prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for WorkspaceSection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"workspace"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
Ok(format!(
|
||||||
|
"## Workspace\n\nWorking directory: `{}`",
|
||||||
|
ctx.workspace_dir.display()
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for RuntimeSection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"runtime"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
let host =
|
||||||
|
hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
|
||||||
|
Ok(format!(
|
||||||
|
"## Runtime\n\nHost: {host} | OS: {} | Model: {}",
|
||||||
|
std::env::consts::OS,
|
||||||
|
ctx.model_name
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PromptSection for DateTimeSection {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"datetime"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build(&self, _ctx: &PromptContext<'_>) -> Result<String> {
|
||||||
|
let now = Local::now();
|
||||||
|
Ok(format!(
|
||||||
|
"## Current Date & Time\n\nTimezone: {}",
|
||||||
|
now.format("%Z")
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn inject_workspace_file(prompt: &mut String, workspace_dir: &Path, filename: &str) {
|
||||||
|
let path = workspace_dir.join(filename);
|
||||||
|
match std::fs::read_to_string(&path) {
|
||||||
|
Ok(content) => {
|
||||||
|
let trimmed = content.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let _ = writeln!(prompt, "### {filename}\n");
|
||||||
|
let truncated = if trimmed.chars().count() > BOOTSTRAP_MAX_CHARS {
|
||||||
|
trimmed
|
||||||
|
.char_indices()
|
||||||
|
.nth(BOOTSTRAP_MAX_CHARS)
|
||||||
|
.map(|(idx, _)| &trimmed[..idx])
|
||||||
|
.unwrap_or(trimmed)
|
||||||
|
} else {
|
||||||
|
trimmed
|
||||||
|
};
|
||||||
|
prompt.push_str(truncated);
|
||||||
|
if truncated.len() < trimmed.len() {
|
||||||
|
let _ = writeln!(
|
||||||
|
prompt,
|
||||||
|
"\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
prompt.push_str("\n\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::tools::traits::Tool;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
struct TestTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for TestTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"test_tool"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"tool desc"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({"type": "object"})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
_args: serde_json::Value,
|
||||||
|
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||||
|
Ok(crate::tools::ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "ok".into(),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prompt_builder_assembles_sections() {
|
||||||
|
let tools: Vec<Box<dyn Tool>> = vec![Box::new(TestTool)];
|
||||||
|
let ctx = PromptContext {
|
||||||
|
workspace_dir: Path::new("/tmp"),
|
||||||
|
model_name: "test-model",
|
||||||
|
tools: &tools,
|
||||||
|
skills: &[],
|
||||||
|
identity_config: None,
|
||||||
|
dispatcher_instructions: "instr",
|
||||||
|
};
|
||||||
|
let prompt = SystemPromptBuilder::with_defaults().build(&ctx).unwrap();
|
||||||
|
assert!(prompt.contains("## Tools"));
|
||||||
|
assert!(prompt.contains("test_tool"));
|
||||||
|
assert!(prompt.contains("instr"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -765,18 +765,16 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
&config.autonomy,
|
&config.autonomy,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
));
|
));
|
||||||
|
|
||||||
let model = config
|
let model = config
|
||||||
.default_model
|
.default_model
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4".into());
|
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||||
let temperature = config.default_temperature;
|
let temperature = config.default_temperature;
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
||||||
(
|
(
|
||||||
config.composio.api_key.as_deref(),
|
config.composio.api_key.as_deref(),
|
||||||
|
|
@ -785,6 +783,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
} else {
|
} else {
|
||||||
(None, None)
|
(None, None)
|
||||||
};
|
};
|
||||||
|
// Build system prompt from workspace identity files + skills
|
||||||
|
let workspace = config.workspace_dir.clone();
|
||||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||||
&security,
|
&security,
|
||||||
runtime,
|
runtime,
|
||||||
|
|
@ -793,14 +793,12 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
composio_entity_id,
|
composio_entity_id,
|
||||||
&config.browser,
|
&config.browser,
|
||||||
&config.http_request,
|
&config.http_request,
|
||||||
&config.workspace_dir,
|
&workspace,
|
||||||
&config.agents,
|
&config.agents,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
&config,
|
&config,
|
||||||
));
|
));
|
||||||
|
|
||||||
// Build system prompt from workspace identity files + skills
|
|
||||||
let workspace = config.workspace_dir.clone();
|
|
||||||
let skills = crate::skills::load_skills(&workspace);
|
let skills = crate::skills::load_skills(&workspace);
|
||||||
|
|
||||||
// Collect tool descriptions for the prompt
|
// Collect tool descriptions for the prompt
|
||||||
|
|
@ -1112,23 +1110,19 @@ mod tests {
|
||||||
message: &str,
|
message: &str,
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
tokio::time::sleep(self.delay).await;
|
tokio::time::sleep(self.delay).await;
|
||||||
Ok(ChatResponse::with_text(format!("echo: {message}")))
|
Ok(format!("echo: {message}"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ToolCallingProvider;
|
struct ToolCallingProvider;
|
||||||
|
|
||||||
fn tool_call_payload() -> ChatResponse {
|
fn tool_call_payload() -> String {
|
||||||
ChatResponse {
|
r#"<tool_call>
|
||||||
text: Some(String::new()),
|
{"name":"mock_price","arguments":{"symbol":"BTC"}}
|
||||||
tool_calls: vec![ToolCall {
|
</tool_call>"#
|
||||||
id: "call_1".into(),
|
.to_string()
|
||||||
name: "mock_price".into(),
|
|
||||||
arguments: r#"{"symbol":"BTC"}"#.into(),
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|
@ -1139,7 +1133,7 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
Ok(tool_call_payload())
|
Ok(tool_call_payload())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1148,14 +1142,12 @@ mod tests {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let has_tool_results = messages
|
let has_tool_results = messages
|
||||||
.iter()
|
.iter()
|
||||||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||||||
if has_tool_results {
|
if has_tool_results {
|
||||||
Ok(ChatResponse::with_text(
|
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
|
||||||
"BTC is currently around $65,000 based on latest tool output.",
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
Ok(tool_call_payload())
|
Ok(tool_call_payload())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,9 @@ pub struct Config {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub scheduler: SchedulerConfig,
|
pub scheduler: SchedulerConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub agent: AgentConfig,
|
||||||
|
|
||||||
/// Model routing rules — route `hint:<name>` to specific provider+model combos.
|
/// Model routing rules — route `hint:<name>` to specific provider+model combos.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub model_routes: Vec<ModelRouteConfig>,
|
pub model_routes: Vec<ModelRouteConfig>,
|
||||||
|
|
@ -209,6 +212,41 @@ impl Default for HardwareConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AgentConfig {
|
||||||
|
#[serde(default = "default_agent_max_tool_iterations")]
|
||||||
|
pub max_tool_iterations: usize,
|
||||||
|
#[serde(default = "default_agent_max_history_messages")]
|
||||||
|
pub max_history_messages: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
pub parallel_tools: bool,
|
||||||
|
#[serde(default = "default_agent_tool_dispatcher")]
|
||||||
|
pub tool_dispatcher: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_agent_max_tool_iterations() -> usize {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_agent_max_history_messages() -> usize {
|
||||||
|
50
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_agent_tool_dispatcher() -> String {
|
||||||
|
"auto".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AgentConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_tool_iterations: default_agent_max_tool_iterations(),
|
||||||
|
max_history_messages: default_agent_max_history_messages(),
|
||||||
|
parallel_tools: false,
|
||||||
|
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Identity (AIEOS / OpenClaw format) ──────────────────────────
|
// ── Identity (AIEOS / OpenClaw format) ──────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -1507,6 +1545,7 @@ impl Default for Config {
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: ReliabilityConfig::default(),
|
reliability: ReliabilityConfig::default(),
|
||||||
scheduler: SchedulerConfig::default(),
|
scheduler: SchedulerConfig::default(),
|
||||||
|
agent: AgentConfig::default(),
|
||||||
model_routes: Vec::new(),
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
|
|
@ -1873,6 +1912,7 @@ mod tests {
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
http_request: HttpRequestConfig::default(),
|
http_request: HttpRequestConfig::default(),
|
||||||
|
agent: AgentConfig::default(),
|
||||||
identity: IdentityConfig::default(),
|
identity: IdentityConfig::default(),
|
||||||
cost: CostConfig::default(),
|
cost: CostConfig::default(),
|
||||||
peripherals: PeripheralsConfig::default(),
|
peripherals: PeripheralsConfig::default(),
|
||||||
|
|
@ -1922,6 +1962,32 @@ default_temperature = 0.7
|
||||||
assert_eq!(parsed.memory.conversation_retention_days, 30);
|
assert_eq!(parsed.memory.conversation_retention_days, 30);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_defaults() {
|
||||||
|
let cfg = AgentConfig::default();
|
||||||
|
assert_eq!(cfg.max_tool_iterations, 10);
|
||||||
|
assert_eq!(cfg.max_history_messages, 50);
|
||||||
|
assert!(!cfg.parallel_tools);
|
||||||
|
assert_eq!(cfg.tool_dispatcher, "auto");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_deserializes() {
|
||||||
|
let raw = r#"
|
||||||
|
default_temperature = 0.7
|
||||||
|
[agent]
|
||||||
|
max_tool_iterations = 20
|
||||||
|
max_history_messages = 80
|
||||||
|
parallel_tools = true
|
||||||
|
tool_dispatcher = "xml"
|
||||||
|
"#;
|
||||||
|
let parsed: Config = toml::from_str(raw).unwrap();
|
||||||
|
assert_eq!(parsed.agent.max_tool_iterations, 20);
|
||||||
|
assert_eq!(parsed.agent.max_history_messages, 80);
|
||||||
|
assert!(parsed.agent.parallel_tools);
|
||||||
|
assert_eq!(parsed.agent.tool_dispatcher, "xml");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_save_and_load_tmpdir() {
|
fn config_save_and_load_tmpdir() {
|
||||||
let dir = std::env::temp_dir().join("zeroclaw_test_config");
|
let dir = std::env::temp_dir().join("zeroclaw_test_config");
|
||||||
|
|
@ -1951,6 +2017,7 @@ default_temperature = 0.7
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
http_request: HttpRequestConfig::default(),
|
http_request: HttpRequestConfig::default(),
|
||||||
|
agent: AgentConfig::default(),
|
||||||
identity: IdentityConfig::default(),
|
identity: IdentityConfig::default(),
|
||||||
cost: CostConfig::default(),
|
cost: CostConfig::default(),
|
||||||
peripherals: PeripheralsConfig::default(),
|
peripherals: PeripheralsConfig::default(),
|
||||||
|
|
|
||||||
|
|
@ -10,14 +10,8 @@
|
||||||
use crate::channels::{Channel, WhatsAppChannel};
|
use crate::channels::{Channel, WhatsAppChannel};
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::memory::{self, Memory, MemoryCategory};
|
use crate::memory::{self, Memory, MemoryCategory};
|
||||||
use crate::observability::{self, Observer};
|
use crate::providers::{self, Provider};
|
||||||
use crate::providers::{self, ChatMessage, Provider};
|
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
||||||
use crate::runtime;
|
|
||||||
use crate::security::{
|
|
||||||
pairing::{constant_time_eq, is_public_bind, PairingGuard},
|
|
||||||
SecurityPolicy,
|
|
||||||
};
|
|
||||||
use crate::tools::{self, Tool};
|
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|
@ -51,35 +45,6 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
||||||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_gateway_reply(reply: String) -> String {
|
|
||||||
if reply.trim().is_empty() {
|
|
||||||
return "Model returned an empty response.".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
reply
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn gateway_agent_reply(state: &AppState, message: &str) -> Result<String> {
|
|
||||||
let mut history = vec![
|
|
||||||
ChatMessage::system(state.system_prompt.as_str()),
|
|
||||||
ChatMessage::user(message),
|
|
||||||
];
|
|
||||||
|
|
||||||
let reply = crate::agent::loop_::run_tool_call_loop(
|
|
||||||
state.provider.as_ref(),
|
|
||||||
&mut history,
|
|
||||||
state.tools_registry.as_ref(),
|
|
||||||
state.observer.as_ref(),
|
|
||||||
"gateway",
|
|
||||||
&state.model,
|
|
||||||
state.temperature,
|
|
||||||
true, // silent — gateway responses go over HTTP
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(normalize_gateway_reply(reply))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// How often the rate limiter sweeps stale IP entries from its map.
|
/// How often the rate limiter sweeps stale IP entries from its map.
|
||||||
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||||
|
|
||||||
|
|
@ -207,9 +172,6 @@ fn client_key_from_headers(headers: &HeaderMap) -> String {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub provider: Arc<dyn Provider>,
|
pub provider: Arc<dyn Provider>,
|
||||||
pub observer: Arc<dyn Observer>,
|
|
||||||
pub tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
|
||||||
pub system_prompt: Arc<String>,
|
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub temperature: f64,
|
pub temperature: f64,
|
||||||
pub mem: Arc<dyn Memory>,
|
pub mem: Arc<dyn Memory>,
|
||||||
|
|
@ -256,55 +218,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
let observer: Arc<dyn Observer> =
|
|
||||||
Arc::from(observability::create_observer(&config.observability));
|
|
||||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
|
||||||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
|
||||||
let security = Arc::new(SecurityPolicy::from_config(
|
|
||||||
&config.autonomy,
|
|
||||||
&config.workspace_dir,
|
|
||||||
));
|
|
||||||
|
|
||||||
let (composio_key, composio_entity_id) = if config.composio.enabled {
|
|
||||||
(
|
|
||||||
config.composio.api_key.as_deref(),
|
|
||||||
Some(config.composio.entity_id.as_str()),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
(None, None)
|
|
||||||
};
|
|
||||||
|
|
||||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
|
||||||
&security,
|
|
||||||
runtime,
|
|
||||||
Arc::clone(&mem),
|
|
||||||
composio_key,
|
|
||||||
composio_entity_id,
|
|
||||||
&config.browser,
|
|
||||||
&config.http_request,
|
|
||||||
&config.workspace_dir,
|
|
||||||
&config.agents,
|
|
||||||
config.api_key.as_deref(),
|
|
||||||
&config,
|
|
||||||
));
|
|
||||||
let skills = crate::skills::load_skills(&config.workspace_dir);
|
|
||||||
let tool_descs: Vec<(&str, &str)> = tools_registry
|
|
||||||
.iter()
|
|
||||||
.map(|tool| (tool.name(), tool.description()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut system_prompt = crate::channels::build_system_prompt(
|
|
||||||
&config.workspace_dir,
|
|
||||||
&model,
|
|
||||||
&tool_descs,
|
|
||||||
&skills,
|
|
||||||
Some(&config.identity),
|
|
||||||
None, // bootstrap_max_chars — no compact context for gateway
|
|
||||||
);
|
|
||||||
system_prompt.push_str(&crate::agent::loop_::build_tool_instructions(
|
|
||||||
tools_registry.as_ref(),
|
|
||||||
));
|
|
||||||
let system_prompt = Arc::new(system_prompt);
|
|
||||||
|
|
||||||
// Extract webhook secret for authentication
|
// Extract webhook secret for authentication
|
||||||
let webhook_secret: Option<Arc<str>> = config
|
let webhook_secret: Option<Arc<str>> = config
|
||||||
|
|
@ -408,9 +322,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
// Build shared state
|
// Build shared state
|
||||||
let state = AppState {
|
let state = AppState {
|
||||||
provider,
|
provider,
|
||||||
observer,
|
|
||||||
tools_registry,
|
|
||||||
system_prompt,
|
|
||||||
model,
|
model,
|
||||||
temperature,
|
temperature,
|
||||||
mem,
|
mem,
|
||||||
|
|
@ -594,9 +505,13 @@ async fn handle_webhook(
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
match gateway_agent_reply(&state, message).await {
|
match state
|
||||||
Ok(reply) => {
|
.provider
|
||||||
let body = serde_json::json!({"response": reply, "model": state.model});
|
.simple_chat(message, &state.model, state.temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
|
let body = serde_json::json!({"response": response, "model": state.model});
|
||||||
(StatusCode::OK, Json(body))
|
(StatusCode::OK, Json(body))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
|
@ -744,10 +659,14 @@ async fn handle_whatsapp_message(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the LLM
|
// Call the LLM
|
||||||
match gateway_agent_reply(&state, &msg.content).await {
|
match state
|
||||||
Ok(reply) => {
|
.provider
|
||||||
|
.simple_chat(&msg.content, &state.model, state.temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
// Send reply via WhatsApp
|
// Send reply via WhatsApp
|
||||||
if let Err(e) = wa.send(&reply, &msg.sender).await {
|
if let Err(e) = wa.send(&response, &msg.sender).await {
|
||||||
tracing::error!("Failed to send WhatsApp reply: {e}");
|
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -966,9 +885,9 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<crate::providers::ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
Ok(crate::providers::ChatResponse::with_text("ok"))
|
Ok("ok".into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1029,36 +948,25 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_app_state(
|
|
||||||
provider: Arc<dyn Provider>,
|
|
||||||
memory: Arc<dyn Memory>,
|
|
||||||
auto_save: bool,
|
|
||||||
) -> AppState {
|
|
||||||
AppState {
|
|
||||||
provider,
|
|
||||||
observer: Arc::new(crate::observability::NoopObserver),
|
|
||||||
tools_registry: Arc::new(Vec::new()),
|
|
||||||
system_prompt: Arc::new("test-system-prompt".into()),
|
|
||||||
model: "test-model".into(),
|
|
||||||
temperature: 0.0,
|
|
||||||
mem: memory,
|
|
||||||
auto_save,
|
|
||||||
webhook_secret: None,
|
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
|
||||||
whatsapp: None,
|
|
||||||
whatsapp_app_secret: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
||||||
let provider_impl = Arc::new(MockProvider::default());
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
let state = test_app_state(provider, memory, false);
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret: None,
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
||||||
|
|
@ -1094,7 +1002,19 @@ mod tests {
|
||||||
let tracking_impl = Arc::new(TrackingMemory::default());
|
let tracking_impl = Arc::new(TrackingMemory::default());
|
||||||
let memory: Arc<dyn Memory> = tracking_impl.clone();
|
let memory: Arc<dyn Memory> = tracking_impl.clone();
|
||||||
|
|
||||||
let state = test_app_state(provider, memory, true);
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: true,
|
||||||
|
webhook_secret: None,
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
let headers = HeaderMap::new();
|
let headers = HeaderMap::new();
|
||||||
|
|
||||||
|
|
@ -1126,110 +1046,6 @@ mod tests {
|
||||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct StructuredToolCallProvider {
|
|
||||||
calls: AtomicUsize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Provider for StructuredToolCallProvider {
|
|
||||||
async fn chat_with_system(
|
|
||||||
&self,
|
|
||||||
_system_prompt: Option<&str>,
|
|
||||||
_message: &str,
|
|
||||||
_model: &str,
|
|
||||||
_temperature: f64,
|
|
||||||
) -> anyhow::Result<crate::providers::ChatResponse> {
|
|
||||||
let turn = self.calls.fetch_add(1, Ordering::SeqCst);
|
|
||||||
|
|
||||||
if turn == 0 {
|
|
||||||
return Ok(crate::providers::ChatResponse {
|
|
||||||
text: Some("Running tool...".into()),
|
|
||||||
tool_calls: vec![crate::providers::ToolCall {
|
|
||||||
id: "call_1".into(),
|
|
||||||
name: "mock_tool".into(),
|
|
||||||
arguments: r#"{"query":"gateway"}"#.into(),
|
|
||||||
}],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(crate::providers::ChatResponse::with_text(
|
|
||||||
"Gateway tool result ready.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MockTool {
|
|
||||||
calls: Arc<AtomicUsize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for MockTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"mock_tool"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Mock tool for gateway tests"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string"}
|
|
||||||
},
|
|
||||||
"required": ["query"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(
|
|
||||||
&self,
|
|
||||||
args: serde_json::Value,
|
|
||||||
) -> anyhow::Result<crate::tools::ToolResult> {
|
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
|
||||||
assert_eq!(args["query"], "gateway");
|
|
||||||
|
|
||||||
Ok(crate::tools::ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: "ok".into(),
|
|
||||||
error: None,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn webhook_executes_structured_tool_calls() {
|
|
||||||
let provider_impl = Arc::new(StructuredToolCallProvider::default());
|
|
||||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
|
||||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
|
||||||
|
|
||||||
let tool_calls = Arc::new(AtomicUsize::new(0));
|
|
||||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockTool {
|
|
||||||
calls: Arc::clone(&tool_calls),
|
|
||||||
})];
|
|
||||||
|
|
||||||
let mut state = test_app_state(provider, memory, false);
|
|
||||||
state.tools_registry = Arc::new(tools);
|
|
||||||
|
|
||||||
let response = handle_webhook(
|
|
||||||
State(state),
|
|
||||||
HeaderMap::new(),
|
|
||||||
Ok(Json(WebhookBody {
|
|
||||||
message: "please use tool".into(),
|
|
||||||
})),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.into_response();
|
|
||||||
|
|
||||||
assert_eq!(response.status(), StatusCode::OK);
|
|
||||||
let payload = response.into_body().collect().await.unwrap().to_bytes();
|
|
||||||
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
|
|
||||||
assert_eq!(parsed["response"], "Gateway tool result ready.");
|
|
||||||
assert_eq!(tool_calls.load(Ordering::SeqCst), 1);
|
|
||||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
|
|
@ -114,6 +114,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: crate::config::ReliabilityConfig::default(),
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||||
|
agent: crate::config::schema::AgentConfig::default(),
|
||||||
model_routes: Vec::new(),
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config,
|
channels_config,
|
||||||
|
|
@ -318,6 +319,7 @@ pub fn run_quick_setup(
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
reliability: crate::config::ReliabilityConfig::default(),
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
scheduler: crate::config::schema::SchedulerConfig::default(),
|
scheduler: crate::config::schema::SchedulerConfig::default(),
|
||||||
|
agent: crate::config::schema::AgentConfig::default(),
|
||||||
model_routes: Vec::new(),
|
model_routes: Vec::new(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider};
|
use crate::providers::traits::{
|
||||||
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
|
Provider, ToolCall as ProviderToolCall,
|
||||||
|
};
|
||||||
|
use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -26,13 +30,76 @@ struct Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ApiChatResponse {
|
struct ChatResponse {
|
||||||
content: Vec<ContentBlock>,
|
content: Vec<ContentBlock>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ContentBlock {
|
struct ContentBlock {
|
||||||
text: String,
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
#[serde(default)]
|
||||||
|
text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeChatRequest {
|
||||||
|
model: String,
|
||||||
|
max_tokens: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
system: Option<String>,
|
||||||
|
messages: Vec<NativeMessage>,
|
||||||
|
temperature: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<NativeToolSpec>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeMessage {
|
||||||
|
role: String,
|
||||||
|
content: Vec<NativeContentOut>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
enum NativeContentOut {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "tool_use")]
|
||||||
|
ToolUse {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
input: serde_json::Value,
|
||||||
|
},
|
||||||
|
#[serde(rename = "tool_result")]
|
||||||
|
ToolResult { tool_use_id: String, content: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolSpec {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
input_schema: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeChatResponse {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Vec<NativeContentIn>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeContentIn {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
#[serde(default)]
|
||||||
|
text: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
input: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
|
|
@ -62,6 +129,186 @@ impl AnthropicProvider {
|
||||||
fn is_setup_token(token: &str) -> bool {
|
fn is_setup_token(token: &str) -> bool {
|
||||||
token.starts_with("sk-ant-oat01-")
|
token.starts_with("sk-ant-oat01-")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn apply_auth(
|
||||||
|
&self,
|
||||||
|
request: reqwest::RequestBuilder,
|
||||||
|
credential: &str,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
|
if Self::is_setup_token(credential) {
|
||||||
|
request.header("Authorization", format!("Bearer {credential}"))
|
||||||
|
} else {
|
||||||
|
request.header("x-api-key", credential)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||||
|
let items = tools?;
|
||||||
|
if items.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|tool| NativeToolSpec {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
input_schema: tool.parameters.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<NativeContentOut>> {
|
||||||
|
let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||||
|
let tool_calls = value
|
||||||
|
.get("tool_calls")
|
||||||
|
.and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
|
||||||
|
|
||||||
|
let mut blocks = Vec::new();
|
||||||
|
if let Some(text) = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|t| !t.is_empty())
|
||||||
|
{
|
||||||
|
blocks.push(NativeContentOut::Text {
|
||||||
|
text: text.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for call in tool_calls {
|
||||||
|
let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
|
||||||
|
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
blocks.push(NativeContentOut::ToolUse {
|
||||||
|
id: call.id,
|
||||||
|
name: call.name,
|
||||||
|
input,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Some(blocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_tool_result_message(content: &str) -> Option<NativeMessage> {
|
||||||
|
let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||||
|
let tool_use_id = value
|
||||||
|
.get("tool_call_id")
|
||||||
|
.and_then(serde_json::Value::as_str)?
|
||||||
|
.to_string();
|
||||||
|
let result = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
Some(NativeMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![NativeContentOut::ToolResult {
|
||||||
|
tool_use_id,
|
||||||
|
content: result,
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<NativeMessage>) {
|
||||||
|
let mut system_prompt = None;
|
||||||
|
let mut native_messages = Vec::new();
|
||||||
|
|
||||||
|
for msg in messages {
|
||||||
|
match msg.role.as_str() {
|
||||||
|
"system" => {
|
||||||
|
if system_prompt.is_none() {
|
||||||
|
system_prompt = Some(msg.content.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"assistant" => {
|
||||||
|
if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
|
||||||
|
native_messages.push(NativeMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: blocks,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
native_messages.push(NativeMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![NativeContentOut::Text {
|
||||||
|
text: msg.content.clone(),
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"tool" => {
|
||||||
|
if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) {
|
||||||
|
native_messages.push(tool_result);
|
||||||
|
} else {
|
||||||
|
native_messages.push(NativeMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![NativeContentOut::Text {
|
||||||
|
text: msg.content.clone(),
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
native_messages.push(NativeMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: vec![NativeContentOut::Text {
|
||||||
|
text: msg.content.clone(),
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(system_prompt, native_messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_text_response(response: ChatResponse) -> anyhow::Result<String> {
|
||||||
|
response
|
||||||
|
.content
|
||||||
|
.into_iter()
|
||||||
|
.find(|c| c.kind == "text")
|
||||||
|
.and_then(|c| c.text)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse {
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
|
||||||
|
for block in response.content {
|
||||||
|
match block.kind.as_str() {
|
||||||
|
"text" => {
|
||||||
|
if let Some(text) = block.text.map(|t| t.trim().to_string()) {
|
||||||
|
if !text.is_empty() {
|
||||||
|
text_parts.push(text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"tool_use" => {
|
||||||
|
let name = block.name.unwrap_or_default();
|
||||||
|
if name.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let arguments = block
|
||||||
|
.input
|
||||||
|
.unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
tool_calls.push(ProviderToolCall {
|
||||||
|
id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
name,
|
||||||
|
arguments: arguments.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ProviderChatResponse {
|
||||||
|
text: if text_parts.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(text_parts.join("\n"))
|
||||||
|
},
|
||||||
|
tool_calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -72,7 +319,7 @@ impl Provider for AnthropicProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
|
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
|
||||||
|
|
@ -97,11 +344,7 @@ impl Provider for AnthropicProvider {
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.json(&request);
|
.json(&request);
|
||||||
|
|
||||||
if Self::is_setup_token(credential) {
|
request = self.apply_auth(request, credential);
|
||||||
request = request.header("Authorization", format!("Bearer {credential}"));
|
|
||||||
} else {
|
|
||||||
request = request.header("x-api-key", credential);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = request.send().await?;
|
let response = request.send().await?;
|
||||||
|
|
||||||
|
|
@ -109,14 +352,50 @@ impl Provider for AnthropicProvider {
|
||||||
return Err(super::api_error("Anthropic", response).await);
|
return Err(super::api_error("Anthropic", response).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
let chat_response: ChatResponse = response.json().await?;
|
||||||
|
Self::parse_text_response(chat_response)
|
||||||
|
}
|
||||||
|
|
||||||
chat_response
|
async fn chat(
|
||||||
.content
|
&self,
|
||||||
.into_iter()
|
request: ProviderChatRequest<'_>,
|
||||||
.next()
|
model: &str,
|
||||||
.map(|c| ProviderChatResponse::with_text(c.text))
|
temperature: f64,
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let (system_prompt, messages) = Self::convert_messages(request.messages);
|
||||||
|
let native_request = NativeChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
max_tokens: 4096,
|
||||||
|
system: system_prompt,
|
||||||
|
messages,
|
||||||
|
temperature,
|
||||||
|
tools: Self::convert_tools(request.tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let req = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/v1/messages", self.base_url))
|
||||||
|
.header("anthropic-version", "2023-06-01")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&native_request);
|
||||||
|
|
||||||
|
let response = self.apply_auth(req, credential).send().await?;
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("Anthropic", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let native_response: NativeChatResponse = response.json().await?;
|
||||||
|
Ok(Self::parse_native_response(native_response))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -241,15 +520,16 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn chat_response_deserializes() {
|
fn chat_response_deserializes() {
|
||||||
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
|
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.content.len(), 1);
|
assert_eq!(resp.content.len(), 1);
|
||||||
assert_eq!(resp.content[0].text, "Hello there!");
|
assert_eq!(resp.content[0].kind, "text");
|
||||||
|
assert_eq!(resp.content[0].text.as_deref(), Some("Hello there!"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn chat_response_empty_content() {
|
fn chat_response_empty_content() {
|
||||||
let json = r#"{"content":[]}"#;
|
let json = r#"{"content":[]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert!(resp.content.is_empty());
|
assert!(resp.content.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,10 +537,10 @@ mod tests {
|
||||||
fn chat_response_multiple_blocks() {
|
fn chat_response_multiple_blocks() {
|
||||||
let json =
|
let json =
|
||||||
r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
|
r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.content.len(), 2);
|
assert_eq!(resp.content.len(), 2);
|
||||||
assert_eq!(resp.content[0].text, "First");
|
assert_eq!(resp.content[0].text.as_deref(), Some("First"));
|
||||||
assert_eq!(resp.content[1].text, "Second");
|
assert_eq!(resp.content[1].text.as_deref(), Some("Second"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,10 @@
|
||||||
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
||||||
//! This module provides a single implementation that works for all of them.
|
//! This module provides a single implementation that works for all of them.
|
||||||
|
|
||||||
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
use crate::providers::traits::{
|
||||||
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
|
Provider, ToolCall as ProviderToolCall,
|
||||||
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -163,12 +166,11 @@ struct ResponseMessage {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
tool_calls: Option<Vec<ApiToolCall>>,
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct ApiToolCall {
|
struct ToolCall {
|
||||||
id: Option<String>,
|
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
kind: Option<String>,
|
kind: Option<String>,
|
||||||
function: Option<Function>,
|
function: Option<Function>,
|
||||||
|
|
@ -254,44 +256,6 @@ fn extract_responses_text(response: ResponsesResponse) -> Option<String> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn map_response_message(message: ResponseMessage) -> ChatResponse {
|
|
||||||
let text = first_nonempty(message.content.as_deref());
|
|
||||||
let tool_calls = message
|
|
||||||
.tool_calls
|
|
||||||
.unwrap_or_default()
|
|
||||||
.into_iter()
|
|
||||||
.enumerate()
|
|
||||||
.filter_map(|(index, call)| map_api_tool_call(call, index))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
ChatResponse { text, tool_calls }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn map_api_tool_call(call: ApiToolCall, index: usize) -> Option<ToolCall> {
|
|
||||||
if call.kind.as_deref().is_some_and(|kind| kind != "function") {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let function = call.function?;
|
|
||||||
let name = function
|
|
||||||
.name
|
|
||||||
.and_then(|value| first_nonempty(Some(value.as_str())))?;
|
|
||||||
let arguments = function
|
|
||||||
.arguments
|
|
||||||
.and_then(|value| first_nonempty(Some(value.as_str())))
|
|
||||||
.unwrap_or_else(|| "{}".to_string());
|
|
||||||
let id = call
|
|
||||||
.id
|
|
||||||
.and_then(|value| first_nonempty(Some(value.as_str())))
|
|
||||||
.unwrap_or_else(|| format!("call_{}", index + 1));
|
|
||||||
|
|
||||||
Some(ToolCall {
|
|
||||||
id,
|
|
||||||
name,
|
|
||||||
arguments,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAiCompatibleProvider {
|
impl OpenAiCompatibleProvider {
|
||||||
fn apply_auth_header(
|
fn apply_auth_header(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -311,7 +275,7 @@ impl OpenAiCompatibleProvider {
|
||||||
system_prompt: Option<&str>,
|
system_prompt: Option<&str>,
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let request = ResponsesRequest {
|
let request = ResponsesRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
input: vec![ResponsesInput {
|
input: vec![ResponsesInput {
|
||||||
|
|
@ -337,7 +301,6 @@ impl OpenAiCompatibleProvider {
|
||||||
let responses: ResponsesResponse = response.json().await?;
|
let responses: ResponsesResponse = response.json().await?;
|
||||||
|
|
||||||
extract_responses_text(responses)
|
extract_responses_text(responses)
|
||||||
.map(ChatResponse::with_text)
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
|
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -350,7 +313,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
|
|
@ -408,13 +371,27 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
let choice = chat_response
|
chat_response
|
||||||
.choices
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
.map(|c| {
|
||||||
|
// If tool_calls are present, serialize the full message as JSON
|
||||||
Ok(map_response_message(choice.message))
|
// so parse_tool_calls can handle the OpenAI-style format
|
||||||
|
if c.message.tool_calls.is_some()
|
||||||
|
&& c.message
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |t| !t.is_empty())
|
||||||
|
{
|
||||||
|
serde_json::to_string(&c.message)
|
||||||
|
.unwrap_or_else(|_| c.message.content.unwrap_or_default())
|
||||||
|
} else {
|
||||||
|
// No tool calls, return content as-is
|
||||||
|
c.message.content.unwrap_or_default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_with_history(
|
async fn chat_with_history(
|
||||||
|
|
@ -422,7 +399,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
|
|
@ -482,13 +459,71 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
let choice = chat_response
|
chat_response
|
||||||
.choices
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
.map(|c| {
|
||||||
|
// If tool_calls are present, serialize the full message as JSON
|
||||||
|
// so parse_tool_calls can handle the OpenAI-style format
|
||||||
|
if c.message.tool_calls.is_some()
|
||||||
|
&& c.message
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |t| !t.is_empty())
|
||||||
|
{
|
||||||
|
serde_json::to_string(&c.message)
|
||||||
|
.unwrap_or_else(|_| c.message.content.unwrap_or_default())
|
||||||
|
} else {
|
||||||
|
// No tool calls, return content as-is
|
||||||
|
c.message.content.unwrap_or_default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||||
|
}
|
||||||
|
|
||||||
Ok(map_response_message(choice.message))
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ProviderChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let text = self
|
||||||
|
.chat_with_history(request.messages, model, temperature)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
|
||||||
|
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
|
||||||
|
let tool_calls = message
|
||||||
|
.tool_calls
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|tc| {
|
||||||
|
let function = tc.function?;
|
||||||
|
let name = function.name?;
|
||||||
|
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
|
||||||
|
Some(ProviderToolCall {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
return Ok(ProviderChatResponse {
|
||||||
|
text: message.content,
|
||||||
|
tool_calls,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ProviderChatResponse {
|
||||||
|
text: Some(text),
|
||||||
|
tool_calls: vec![],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -573,20 +608,6 @@ mod tests {
|
||||||
assert!(resp.choices.is_empty());
|
assert!(resp.choices.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn response_with_tool_calls_maps_structured_data() {
|
|
||||||
let json = r#"{"choices":[{"message":{"content":"Running checks","tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"command\":\"pwd\"}"}}]}}]}"#;
|
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
|
||||||
let choice = resp.choices.into_iter().next().unwrap();
|
|
||||||
|
|
||||||
let mapped = map_response_message(choice.message);
|
|
||||||
assert_eq!(mapped.text.as_deref(), Some("Running checks"));
|
|
||||||
assert_eq!(mapped.tool_calls.len(), 1);
|
|
||||||
assert_eq!(mapped.tool_calls[0].id, "call_1");
|
|
||||||
assert_eq!(mapped.tool_calls[0].name, "shell");
|
|
||||||
assert_eq!(mapped.tool_calls[0].arguments, r#"{"command":"pwd"}"#);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn x_api_key_auth_style() {
|
fn x_api_key_auth_style() {
|
||||||
let p = OpenAiCompatibleProvider::new(
|
let p = OpenAiCompatibleProvider::new(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
|
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
|
||||||
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
|
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
|
||||||
|
|
||||||
use crate::providers::traits::{ChatResponse, Provider};
|
use crate::providers::traits::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use directories::UserDirs;
|
use directories::UserDirs;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
|
@ -260,7 +260,7 @@ impl Provider for GeminiProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let auth = self.auth.as_ref().ok_or_else(|| {
|
let auth = self.auth.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"Gemini API key not found. Options:\n\
|
"Gemini API key not found. Options:\n\
|
||||||
|
|
@ -319,7 +319,6 @@ impl Provider for GeminiProvider {
|
||||||
.and_then(|c| c.into_iter().next())
|
.and_then(|c| c.into_iter().next())
|
||||||
.and_then(|c| c.content.parts.into_iter().next())
|
.and_then(|c| c.content.parts.into_iter().next())
|
||||||
.and_then(|p| p.text)
|
.and_then(|p| p.text)
|
||||||
.map(ChatResponse::with_text)
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from Gemini"))
|
.ok_or_else(|| anyhow::anyhow!("No response from Gemini"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,10 @@ pub mod router;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
pub use traits::{
|
||||||
|
ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall,
|
||||||
|
ToolResultMessage,
|
||||||
|
};
|
||||||
|
|
||||||
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||||
use reliable::ReliableProvider;
|
use reliable::ReliableProvider;
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider};
|
use crate::providers::traits::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -61,7 +61,7 @@ impl Provider for OllamaProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
||||||
if let Some(sys) = system_prompt {
|
if let Some(sys) = system_prompt {
|
||||||
|
|
@ -93,9 +93,7 @@ impl Provider for OllamaProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
Ok(ProviderChatResponse::with_text(
|
Ok(chat_response.message.content)
|
||||||
chat_response.message.content,
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
use crate::providers::traits::{ChatResponse, Provider};
|
use crate::providers::traits::{
|
||||||
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
|
Provider, ToolCall as ProviderToolCall,
|
||||||
|
};
|
||||||
|
use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -22,7 +26,7 @@ struct Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ApiChatResponse {
|
struct ChatResponse {
|
||||||
choices: Vec<Choice>,
|
choices: Vec<Choice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -36,6 +40,75 @@ struct ResponseMessage {
|
||||||
content: String,
|
content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeChatRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<NativeMessage>,
|
||||||
|
temperature: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<NativeToolSpec>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_choice: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeMessage {
|
||||||
|
role: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_call_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolSpec {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
function: NativeToolFunctionSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolFunctionSpec {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeToolCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||||
|
kind: Option<String>,
|
||||||
|
function: NativeFunctionCall,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeFunctionCall {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeChatResponse {
|
||||||
|
choices: Vec<NativeChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeChoice {
|
||||||
|
message: NativeResponseMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeResponseMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
impl OpenAiProvider {
|
impl OpenAiProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(api_key: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -47,6 +120,107 @@ impl OpenAiProvider {
|
||||||
.unwrap_or_else(|_| Client::new()),
|
.unwrap_or_else(|_| Client::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||||
|
tools.map(|items| {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|tool| NativeToolSpec {
|
||||||
|
kind: "function".to_string(),
|
||||||
|
function: NativeToolFunctionSpec {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.parameters.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
if m.role == "assistant" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||||
|
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||||
|
if let Ok(parsed_calls) =
|
||||||
|
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||||
|
tool_calls_value.clone(),
|
||||||
|
)
|
||||||
|
{
|
||||||
|
let tool_calls = parsed_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| NativeToolCall {
|
||||||
|
id: Some(tc.id),
|
||||||
|
kind: Some("function".to_string()),
|
||||||
|
function: NativeFunctionCall {
|
||||||
|
name: tc.name,
|
||||||
|
arguments: tc.arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
return NativeMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.role == "tool" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||||
|
let tool_call_id = value
|
||||||
|
.get("tool_call_id")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
return NativeMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NativeMessage {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: Some(m.content.clone()),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
|
||||||
|
let tool_calls = message
|
||||||
|
.tool_calls
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| ProviderToolCall {
|
||||||
|
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: tc.function.arguments,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
ProviderChatResponse {
|
||||||
|
text: message.content,
|
||||||
|
tool_calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -57,7 +231,7 @@ impl Provider for OpenAiProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
})?;
|
})?;
|
||||||
|
|
@ -94,15 +268,60 @@ impl Provider for OpenAiProvider {
|
||||||
return Err(super::api_error("OpenAI", response).await);
|
return Err(super::api_error("OpenAI", response).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
let chat_response: ChatResponse = response.json().await?;
|
||||||
|
|
||||||
chat_response
|
chat_response
|
||||||
.choices
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| ChatResponse::with_text(c.message.content))
|
.map(|c| c.message.content)
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ProviderChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tools = Self::convert_tools(request.tools);
|
||||||
|
let native_request = NativeChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: Self::convert_messages(request.messages),
|
||||||
|
temperature,
|
||||||
|
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||||
|
tools,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
|
.header("Authorization", format!("Bearer {api_key}"))
|
||||||
|
.json(&native_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("OpenAI", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let native_response: NativeChatResponse = response.json().await?;
|
||||||
|
let message = native_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
|
||||||
|
Ok(Self::parse_native_response(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -184,7 +403,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_deserializes_single_choice() {
|
fn response_deserializes_single_choice() {
|
||||||
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
|
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.choices.len(), 1);
|
assert_eq!(resp.choices.len(), 1);
|
||||||
assert_eq!(resp.choices[0].message.content, "Hi!");
|
assert_eq!(resp.choices[0].message.content, "Hi!");
|
||||||
}
|
}
|
||||||
|
|
@ -192,14 +411,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_deserializes_empty_choices() {
|
fn response_deserializes_empty_choices() {
|
||||||
let json = r#"{"choices":[]}"#;
|
let json = r#"{"choices":[]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert!(resp.choices.is_empty());
|
assert!(resp.choices.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_deserializes_multiple_choices() {
|
fn response_deserializes_multiple_choices() {
|
||||||
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
|
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.choices.len(), 2);
|
assert_eq!(resp.choices.len(), 2);
|
||||||
assert_eq!(resp.choices[0].message.content, "A");
|
assert_eq!(resp.choices[0].message.content, "A");
|
||||||
}
|
}
|
||||||
|
|
@ -207,7 +426,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_with_unicode() {
|
fn response_with_unicode() {
|
||||||
let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#;
|
let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.choices[0].message.content, "こんにちは 🦀");
|
assert_eq!(resp.choices[0].message.content, "こんにちは 🦀");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -215,7 +434,7 @@ mod tests {
|
||||||
fn response_with_long_content() {
|
fn response_with_long_content() {
|
||||||
let long = "x".repeat(100_000);
|
let long = "x".repeat(100_000);
|
||||||
let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#);
|
let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#);
|
||||||
let resp: ApiChatResponse = serde_json::from_str(&json).unwrap();
|
let resp: ChatResponse = serde_json::from_str(&json).unwrap();
|
||||||
assert_eq!(resp.choices[0].message.content.len(), 100_000);
|
assert_eq!(resp.choices[0].message.content.len(), 100_000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
use crate::providers::traits::{ChatMessage, ChatResponse, Provider};
|
use crate::providers::traits::{
|
||||||
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
|
Provider, ToolCall as ProviderToolCall,
|
||||||
|
};
|
||||||
|
use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -36,6 +40,75 @@ struct ResponseMessage {
|
||||||
content: String,
|
content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeChatRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<NativeMessage>,
|
||||||
|
temperature: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<NativeToolSpec>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_choice: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeMessage {
|
||||||
|
role: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_call_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolSpec {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
function: NativeToolFunctionSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolFunctionSpec {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeToolCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||||
|
kind: Option<String>,
|
||||||
|
function: NativeFunctionCall,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeFunctionCall {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeChatResponse {
|
||||||
|
choices: Vec<NativeChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeChoice {
|
||||||
|
message: NativeResponseMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct NativeResponseMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
impl OpenRouterProvider {
|
impl OpenRouterProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(api_key: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -47,6 +120,111 @@ impl OpenRouterProvider {
|
||||||
.unwrap_or_else(|_| Client::new()),
|
.unwrap_or_else(|_| Client::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||||
|
let items = tools?;
|
||||||
|
if items.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|tool| NativeToolSpec {
|
||||||
|
kind: "function".to_string(),
|
||||||
|
function: NativeToolFunctionSpec {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.parameters.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
if m.role == "assistant" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||||
|
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||||
|
if let Ok(parsed_calls) =
|
||||||
|
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||||
|
tool_calls_value.clone(),
|
||||||
|
)
|
||||||
|
{
|
||||||
|
let tool_calls = parsed_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| NativeToolCall {
|
||||||
|
id: Some(tc.id),
|
||||||
|
kind: Some("function".to_string()),
|
||||||
|
function: NativeFunctionCall {
|
||||||
|
name: tc.name,
|
||||||
|
arguments: tc.arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
return NativeMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.role == "tool" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||||
|
let tool_call_id = value
|
||||||
|
.get("tool_call_id")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
return NativeMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NativeMessage {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: Some(m.content.clone()),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
|
||||||
|
let tool_calls = message
|
||||||
|
.tool_calls
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| ProviderToolCall {
|
||||||
|
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: tc.function.arguments,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
ProviderChatResponse {
|
||||||
|
text: message.content,
|
||||||
|
tool_calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -71,7 +249,7 @@ impl Provider for OpenRouterProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let api_key = self.api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
|
|
@ -118,7 +296,7 @@ impl Provider for OpenRouterProvider {
|
||||||
.choices
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| ChatResponse::with_text(c.message.content))
|
.map(|c| c.message.content)
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,7 +305,7 @@ impl Provider for OpenRouterProvider {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let api_key = self.api_key.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
|
|
@ -168,9 +346,59 @@ impl Provider for OpenRouterProvider {
|
||||||
.choices
|
.choices
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.next()
|
.next()
|
||||||
.map(|c| ChatResponse::with_text(c.message.content))
|
.map(|c| c.message.content)
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ProviderChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let api_key = self.api_key.as_ref().ok_or_else(|| anyhow::anyhow!(
|
||||||
|
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||||
|
))?;
|
||||||
|
|
||||||
|
let tools = Self::convert_tools(request.tools);
|
||||||
|
let native_request = NativeChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: Self::convert_messages(request.messages),
|
||||||
|
temperature,
|
||||||
|
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||||
|
tools,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
|
.header("Authorization", format!("Bearer {api_key}"))
|
||||||
|
.header(
|
||||||
|
"HTTP-Referer",
|
||||||
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
)
|
||||||
|
.header("X-Title", "ZeroClaw")
|
||||||
|
.json(&native_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("OpenRouter", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let native_response: NativeChatResponse = response.json().await?;
|
||||||
|
let message = native_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
|
||||||
|
Ok(Self::parse_native_response(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use super::traits::{ChatMessage, ChatResponse};
|
use super::traits::ChatMessage;
|
||||||
use super::Provider;
|
use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -156,7 +156,7 @@ impl Provider for ReliableProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let models = self.model_chain(model);
|
let models = self.model_chain(model);
|
||||||
let mut failures = Vec::new();
|
let mut failures = Vec::new();
|
||||||
|
|
||||||
|
|
@ -254,7 +254,7 @@ impl Provider for ReliableProvider {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let models = self.model_chain(model);
|
let models = self.model_chain(model);
|
||||||
let mut failures = Vec::new();
|
let mut failures = Vec::new();
|
||||||
|
|
||||||
|
|
@ -359,12 +359,12 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
if attempt <= self.fail_until_attempt {
|
if attempt <= self.fail_until_attempt {
|
||||||
anyhow::bail!(self.error);
|
anyhow::bail!(self.error);
|
||||||
}
|
}
|
||||||
Ok(ChatResponse::with_text(self.response))
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_with_history(
|
async fn chat_with_history(
|
||||||
|
|
@ -372,12 +372,12 @@ mod tests {
|
||||||
_messages: &[ChatMessage],
|
_messages: &[ChatMessage],
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
if attempt <= self.fail_until_attempt {
|
if attempt <= self.fail_until_attempt {
|
||||||
anyhow::bail!(self.error);
|
anyhow::bail!(self.error);
|
||||||
}
|
}
|
||||||
Ok(ChatResponse::with_text(self.response))
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -397,13 +397,13 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
self.models_seen.lock().unwrap().push(model.to_string());
|
self.models_seen.lock().unwrap().push(model.to_string());
|
||||||
if self.fail_models.contains(&model) {
|
if self.fail_models.contains(&model) {
|
||||||
anyhow::bail!("500 model {} unavailable", model);
|
anyhow::bail!("500 model {} unavailable", model);
|
||||||
}
|
}
|
||||||
Ok(ChatResponse::with_text(self.response))
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -426,8 +426,8 @@ mod tests {
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "ok");
|
assert_eq!(result, "ok");
|
||||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -448,8 +448,8 @@ mod tests {
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "recovered");
|
assert_eq!(result, "recovered");
|
||||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -483,8 +483,8 @@ mod tests {
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "from fallback");
|
assert_eq!(result, "from fallback");
|
||||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
}
|
}
|
||||||
|
|
@ -517,7 +517,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let err = provider
|
let err = provider
|
||||||
.chat("hello", "test", 0.0)
|
.simple_chat("hello", "test", 0.0)
|
||||||
.await
|
.await
|
||||||
.expect_err("all providers should fail");
|
.expect_err("all providers should fail");
|
||||||
let msg = err.to_string();
|
let msg = err.to_string();
|
||||||
|
|
@ -572,8 +572,8 @@ mod tests {
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "from fallback");
|
assert_eq!(result, "from fallback");
|
||||||
// Primary should have been called only once (no retries)
|
// Primary should have been called only once (no retries)
|
||||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
|
|
@ -601,7 +601,7 @@ mod tests {
|
||||||
.chat_with_history(&messages, "test", 0.0)
|
.chat_with_history(&messages, "test", 0.0)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "history ok");
|
assert_eq!(result, "history ok");
|
||||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -640,7 +640,7 @@ mod tests {
|
||||||
.chat_with_history(&messages, "test", 0.0)
|
.chat_with_history(&messages, "test", 0.0)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "fallback ok");
|
assert_eq!(result, "fallback ok");
|
||||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
}
|
}
|
||||||
|
|
@ -827,7 +827,7 @@ mod tests {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.as_ref()
|
self.as_ref()
|
||||||
.chat_with_system(system_prompt, message, model, temperature)
|
.chat_with_system(system_prompt, message, model, temperature)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use super::traits::{ChatMessage, ChatResponse};
|
use super::traits::{ChatMessage, ChatRequest, ChatResponse};
|
||||||
use super::Provider;
|
use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -98,7 +98,7 @@ impl Provider for RouterProvider {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let (provider_idx, resolved_model) = self.resolve(model);
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
|
|
||||||
let (provider_name, provider) = &self.providers[provider_idx];
|
let (provider_name, provider) = &self.providers[provider_idx];
|
||||||
|
|
@ -118,7 +118,7 @@ impl Provider for RouterProvider {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let (provider_idx, resolved_model) = self.resolve(model);
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
let (_, provider) = &self.providers[provider_idx];
|
let (_, provider) = &self.providers[provider_idx];
|
||||||
provider
|
provider
|
||||||
|
|
@ -126,6 +126,24 @@ impl Provider for RouterProvider {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
|
let (_, provider) = &self.providers[provider_idx];
|
||||||
|
provider.chat(request, &resolved_model, temperature).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
self.providers
|
||||||
|
.get(self.default_index)
|
||||||
|
.map(|(_, p)| p.supports_native_tools())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
for (name, provider) in &self.providers {
|
for (name, provider) in &self.providers {
|
||||||
tracing::info!(provider = name, "Warming up routed provider");
|
tracing::info!(provider = name, "Warming up routed provider");
|
||||||
|
|
@ -175,10 +193,10 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
*self.last_model.lock().unwrap() = model.to_string();
|
*self.last_model.lock().unwrap() = model.to_string();
|
||||||
Ok(ChatResponse::with_text(self.response))
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,7 +247,7 @@ mod tests {
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.as_ref()
|
self.as_ref()
|
||||||
.chat_with_system(system_prompt, message, model, temperature)
|
.chat_with_system(system_prompt, message, model, temperature)
|
||||||
.await
|
.await
|
||||||
|
|
@ -246,8 +264,11 @@ mod tests {
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap();
|
let result = router
|
||||||
assert_eq!(result.text_or_empty(), "smart-response");
|
.simple_chat("hello", "hint:reasoning", 0.5)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "smart-response");
|
||||||
assert_eq!(mocks[1].call_count(), 1);
|
assert_eq!(mocks[1].call_count(), 1);
|
||||||
assert_eq!(mocks[1].last_model(), "claude-opus");
|
assert_eq!(mocks[1].last_model(), "claude-opus");
|
||||||
assert_eq!(mocks[0].call_count(), 0);
|
assert_eq!(mocks[0].call_count(), 0);
|
||||||
|
|
@ -260,8 +281,8 @@ mod tests {
|
||||||
vec![("fast", "fast", "llama-3-70b")],
|
vec![("fast", "fast", "llama-3-70b")],
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = router.chat("hello", "hint:fast", 0.5).await.unwrap();
|
let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "fast-response");
|
assert_eq!(result, "fast-response");
|
||||||
assert_eq!(mocks[0].call_count(), 1);
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
assert_eq!(mocks[0].last_model(), "llama-3-70b");
|
assert_eq!(mocks[0].last_model(), "llama-3-70b");
|
||||||
}
|
}
|
||||||
|
|
@ -273,8 +294,11 @@ mod tests {
|
||||||
vec![],
|
vec![],
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap();
|
let result = router
|
||||||
assert_eq!(result.text_or_empty(), "default-response");
|
.simple_chat("hello", "hint:nonexistent", 0.5)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "default-response");
|
||||||
assert_eq!(mocks[0].call_count(), 1);
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
// Falls back to default with the hint as model name
|
// Falls back to default with the hint as model name
|
||||||
assert_eq!(mocks[0].last_model(), "hint:nonexistent");
|
assert_eq!(mocks[0].last_model(), "hint:nonexistent");
|
||||||
|
|
@ -291,10 +315,10 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let result = router
|
let result = router
|
||||||
.chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
|
.simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "primary-response");
|
assert_eq!(result, "primary-response");
|
||||||
assert_eq!(mocks[0].call_count(), 1);
|
assert_eq!(mocks[0].call_count(), 1);
|
||||||
assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
|
assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
|
||||||
}
|
}
|
||||||
|
|
@ -355,7 +379,7 @@ mod tests {
|
||||||
.chat_with_system(Some("system"), "hello", "model", 0.5)
|
.chat_with_system(Some("system"), "hello", "model", 0.5)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(result.text_or_empty(), "response");
|
assert_eq!(result, "response");
|
||||||
assert_eq!(mock.call_count(), 1);
|
assert_eq!(mock.call_count(), 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -29,6 +30,13 @@ impl ChatMessage {
|
||||||
content: content.into(),
|
content: content.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn tool(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "tool".into(),
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A tool call requested by the LLM.
|
/// A tool call requested by the LLM.
|
||||||
|
|
@ -49,14 +57,6 @@ pub struct ChatResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatResponse {
|
impl ChatResponse {
|
||||||
/// Convenience: construct a plain text response with no tool calls.
|
|
||||||
pub fn with_text(text: impl Into<String>) -> Self {
|
|
||||||
Self {
|
|
||||||
text: Some(text.into()),
|
|
||||||
tool_calls: vec![],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// True when the LLM wants to invoke at least one tool.
|
/// True when the LLM wants to invoke at least one tool.
|
||||||
pub fn has_tool_calls(&self) -> bool {
|
pub fn has_tool_calls(&self) -> bool {
|
||||||
!self.tool_calls.is_empty()
|
!self.tool_calls.is_empty()
|
||||||
|
|
@ -68,6 +68,13 @@ impl ChatResponse {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Request payload for provider chat calls.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ChatRequest<'a> {
|
||||||
|
pub messages: &'a [ChatMessage],
|
||||||
|
pub tools: Option<&'a [ToolSpec]>,
|
||||||
|
}
|
||||||
|
|
||||||
/// A tool result to feed back to the LLM.
|
/// A tool result to feed back to the LLM.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResultMessage {
|
pub struct ToolResultMessage {
|
||||||
|
|
@ -77,7 +84,7 @@ pub struct ToolResultMessage {
|
||||||
|
|
||||||
/// A message in a multi-turn conversation, including tool interactions.
|
/// A message in a multi-turn conversation, including tool interactions.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type", content = "data")]
|
||||||
pub enum ConversationMessage {
|
pub enum ConversationMessage {
|
||||||
/// Regular chat message (system, user, assistant).
|
/// Regular chat message (system, user, assistant).
|
||||||
Chat(ChatMessage),
|
Chat(ChatMessage),
|
||||||
|
|
@ -86,29 +93,34 @@ pub enum ConversationMessage {
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
},
|
},
|
||||||
/// Result of a tool execution, fed back to the LLM.
|
/// Results of tool executions, fed back to the LLM.
|
||||||
ToolResult(ToolResultMessage),
|
ToolResults(Vec<ToolResultMessage>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
async fn chat(
|
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||||
|
///
|
||||||
|
/// This is the preferred API for non-agentic direct interactions.
|
||||||
|
async fn simple_chat(
|
||||||
&self,
|
&self,
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
self.chat_with_system(None, message, model, temperature)
|
self.chat_with_system(None, message, model, temperature).await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// One-shot chat with optional system prompt.
|
||||||
|
///
|
||||||
|
/// Kept for compatibility and advanced one-shot prompting.
|
||||||
async fn chat_with_system(
|
async fn chat_with_system(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: Option<&str>,
|
system_prompt: Option<&str>,
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse>;
|
) -> anyhow::Result<String>;
|
||||||
|
|
||||||
/// Multi-turn conversation. Default implementation extracts the last user
|
/// Multi-turn conversation. Default implementation extracts the last user
|
||||||
/// message and delegates to `chat_with_system`.
|
/// message and delegates to `chat_with_system`.
|
||||||
|
|
@ -117,7 +129,7 @@ pub trait Provider: Send + Sync {
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<String> {
|
||||||
let system = messages
|
let system = messages
|
||||||
.iter()
|
.iter()
|
||||||
.find(|m| m.role == "system")
|
.find(|m| m.role == "system")
|
||||||
|
|
@ -131,6 +143,27 @@ pub trait Provider: Send + Sync {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Structured chat API for agent loop callers.
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
let text = self
|
||||||
|
.chat_with_history(request.messages, model, temperature)
|
||||||
|
.await?;
|
||||||
|
Ok(ChatResponse {
|
||||||
|
text: Some(text),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether provider supports native tool calls over API.
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||||
/// Default implementation is a no-op; providers with HTTP clients should override.
|
/// Default implementation is a no-op; providers with HTTP clients should override.
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
|
|
@ -153,6 +186,9 @@ mod tests {
|
||||||
|
|
||||||
let asst = ChatMessage::assistant("Hi there");
|
let asst = ChatMessage::assistant("Hi there");
|
||||||
assert_eq!(asst.role, "assistant");
|
assert_eq!(asst.role, "assistant");
|
||||||
|
|
||||||
|
let tool = ChatMessage::tool("{}");
|
||||||
|
assert_eq!(tool.role, "tool");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -194,11 +230,11 @@ mod tests {
|
||||||
let json = serde_json::to_string(&chat).unwrap();
|
let json = serde_json::to_string(&chat).unwrap();
|
||||||
assert!(json.contains("\"type\":\"Chat\""));
|
assert!(json.contains("\"type\":\"Chat\""));
|
||||||
|
|
||||||
let tool_result = ConversationMessage::ToolResult(ToolResultMessage {
|
let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
|
||||||
tool_call_id: "1".into(),
|
tool_call_id: "1".into(),
|
||||||
content: "done".into(),
|
content: "done".into(),
|
||||||
});
|
}]);
|
||||||
let json = serde_json::to_string(&tool_result).unwrap();
|
let json = serde_json::to_string(&tool_result).unwrap();
|
||||||
assert!(json.contains("\"type\":\"ToolResult\""));
|
assert!(json.contains("\"type\":\"ToolResults\""));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -221,15 +221,10 @@ impl Tool for DelegateTool {
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let has_tool_calls = response.has_tool_calls();
|
let mut rendered = response;
|
||||||
let mut rendered = response.text.unwrap_or_default();
|
|
||||||
if rendered.trim().is_empty() {
|
if rendered.trim().is_empty() {
|
||||||
if has_tool_calls {
|
|
||||||
rendered = "[Tool-only response; no text content]".to_string();
|
|
||||||
} else {
|
|
||||||
rendered = "[Empty response]".to_string();
|
rendered = "[Empty response]".to_string();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue