fix(channels): execute tool calls in channel runtime (#302)
* fix(channels): execute tool calls in channel runtime (#302) * chore(fmt): align repo formatting with rustfmt 1.92
This commit is contained in:
parent
efabe9703f
commit
9d29f30a31
17 changed files with 483 additions and 127 deletions
|
|
@ -16,7 +16,12 @@ pub struct DiscordChannel {
|
|||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>, listen_to_bots: bool) -> Self {
|
||||
pub fn new(
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
|
|
|
|||
|
|
@ -20,10 +20,15 @@ pub use telegram::TelegramChannel;
|
|||
pub use traits::Channel;
|
||||
pub use whatsapp::WhatsAppChannel;
|
||||
|
||||
use crate::agent::loop_::{agent_turn, build_tool_instructions};
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::observability::{self, Observer};
|
||||
use crate::providers::{self, ChatMessage, Provider};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -46,6 +51,8 @@ struct ChannelRuntimeContext {
|
|||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
provider: Arc<dyn Provider>,
|
||||
memory: Arc<dyn Memory>,
|
||||
tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||||
observer: Arc<dyn Observer>,
|
||||
system_prompt: Arc<String>,
|
||||
model: Arc<String>,
|
||||
temperature: f64,
|
||||
|
|
@ -166,11 +173,18 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system(ctx.system_prompt.as_str()),
|
||||
ChatMessage::user(&enriched_message),
|
||||
];
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
ctx.provider.chat_with_system(
|
||||
Some(ctx.system_prompt.as_str()),
|
||||
&enriched_message,
|
||||
agent_turn(
|
||||
ctx.provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
ctx.model.as_str(),
|
||||
ctx.temperature,
|
||||
),
|
||||
|
|
@ -323,7 +337,8 @@ pub fn build_system_prompt(
|
|||
prompt.push_str("```\n<invoke>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</invoke>\n```\n\n");
|
||||
prompt.push_str("You may use multiple tool calls in a single response. ");
|
||||
prompt.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||
prompt.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
prompt
|
||||
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
}
|
||||
|
||||
// ── 2. Safety ───────────────────────────────────────────────
|
||||
|
|
@ -674,6 +689,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
tracing::warn!("Provider warmup failed (non-fatal): {e}");
|
||||
}
|
||||
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||
let security = Arc::new(SecurityPolicy::from_config(
|
||||
&config.autonomy,
|
||||
&config.workspace_dir,
|
||||
));
|
||||
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
|
|
@ -685,6 +709,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
config.api_key.as_deref(),
|
||||
)?);
|
||||
|
||||
let composio_key = if config.composio.enabled {
|
||||
config.composio.api_key.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
));
|
||||
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let skills = crate::skills::load_skills(&workspace);
|
||||
|
|
@ -723,14 +763,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
|
||||
));
|
||||
}
|
||||
if config.composio.enabled {
|
||||
tool_descs.push((
|
||||
"composio",
|
||||
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
|
||||
));
|
||||
}
|
||||
if !config.agents.is_empty() {
|
||||
tool_descs.push((
|
||||
"delegate",
|
||||
"Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.",
|
||||
));
|
||||
}
|
||||
|
||||
let system_prompt = build_system_prompt(
|
||||
let mut system_prompt = build_system_prompt(
|
||||
&workspace,
|
||||
&model,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
Some(&config.identity),
|
||||
);
|
||||
system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref()));
|
||||
|
||||
if !skills.is_empty() {
|
||||
println!(
|
||||
|
|
@ -875,6 +928,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
channels_by_name,
|
||||
provider: Arc::clone(&provider),
|
||||
memory: Arc::clone(&mem),
|
||||
tools_registry: Arc::clone(&tools_registry),
|
||||
observer,
|
||||
system_prompt: Arc::new(system_prompt),
|
||||
model: Arc::new(model.clone()),
|
||||
temperature,
|
||||
|
|
@ -895,7 +950,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::providers::Provider;
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::{ChatMessage, Provider};
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -967,6 +1024,131 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
struct ToolCallingProvider;
|
||||
|
||||
fn tool_call_payload() -> String {
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mock_price",
|
||||
"arguments": "{\"symbol\":\"BTC\"}"
|
||||
}
|
||||
}]
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for ToolCallingProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(tool_call_payload())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let has_tool_results = messages
|
||||
.iter()
|
||||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||||
if has_tool_results {
|
||||
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
|
||||
} else {
|
||||
Ok(tool_call_payload())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MockPriceTool;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Tool for MockPriceTool {
|
||||
fn name(&self) -> &str {
|
||||
"mock_price"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Return a mocked BTC price"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": { "type": "string" }
|
||||
},
|
||||
"required": ["symbol"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let symbol = args.get("symbol").and_then(serde_json::Value::as_str);
|
||||
if symbol != Some("BTC") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("unexpected symbol".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 1);
|
||||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||||
assert!(!sent_messages[0].contains("mock_price"));
|
||||
}
|
||||
|
||||
struct NoopMemory;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
|
@ -1030,6 +1212,8 @@ mod tests {
|
|||
delay: Duration::from_millis(250),
|
||||
}),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
|
|
@ -1269,7 +1453,10 @@ mod tests {
|
|||
|
||||
// Reproduces the production crash path where channel logs truncate at 80 chars.
|
||||
let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80));
|
||||
assert!(result.is_ok(), "truncate_with_ellipsis should never panic on UTF-8");
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"truncate_with_ellipsis should never panic on UTF-8"
|
||||
);
|
||||
|
||||
let truncated = result.unwrap();
|
||||
assert!(!truncated.is_empty());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue