fix(channels): execute tool calls in channel runtime (#302)

* fix(channels): execute tool calls in channel runtime (#302)

* chore(fmt): align repo formatting with rustfmt 1.92
This commit is contained in:
Chummy 2026-02-16 18:07:01 +08:00 committed by GitHub
parent efabe9703f
commit 9d29f30a31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 483 additions and 127 deletions

View file

@ -339,7 +339,7 @@ struct ParsedToolCall {
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
async fn agent_turn(
pub(crate) async fn agent_turn(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
tools_registry: &[Box<dyn Tool>],
@ -414,7 +414,7 @@ async fn agent_turn(
/// Build the tool instruction block for the system prompt so the LLM knows
/// how to invoke tools.
fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
pub(crate) fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
let mut instructions = String::new();
instructions.push_str("\n## Tool Use Protocol\n\n");
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");

View file

@ -16,7 +16,12 @@ pub struct DiscordChannel {
}
impl DiscordChannel {
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>, listen_to_bots: bool) -> Self {
pub fn new(
bot_token: String,
guild_id: Option<String>,
allowed_users: Vec<String>,
listen_to_bots: bool,
) -> Self {
Self {
bot_token,
guild_id,

View file

@ -20,10 +20,15 @@ pub use telegram::TelegramChannel;
pub use traits::Channel;
pub use whatsapp::WhatsAppChannel;
use crate::agent::loop_::{agent_turn, build_tool_instructions};
use crate::config::Config;
use crate::identity;
use crate::memory::{self, Memory};
use crate::providers::{self, Provider};
use crate::observability::{self, Observer};
use crate::providers::{self, ChatMessage, Provider};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use std::collections::HashMap;
@ -46,6 +51,8 @@ struct ChannelRuntimeContext {
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
provider: Arc<dyn Provider>,
memory: Arc<dyn Memory>,
tools_registry: Arc<Vec<Box<dyn Tool>>>,
observer: Arc<dyn Observer>,
system_prompt: Arc<String>,
model: Arc<String>,
temperature: f64,
@ -166,11 +173,18 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
println!(" ⏳ Processing message...");
let started_at = Instant::now();
let mut history = vec![
ChatMessage::system(ctx.system_prompt.as_str()),
ChatMessage::user(&enriched_message),
];
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
ctx.provider.chat_with_system(
Some(ctx.system_prompt.as_str()),
&enriched_message,
agent_turn(
ctx.provider.as_ref(),
&mut history,
ctx.tools_registry.as_ref(),
ctx.observer.as_ref(),
ctx.model.as_str(),
ctx.temperature,
),
@ -323,7 +337,8 @@ pub fn build_system_prompt(
prompt.push_str("```\n<invoke>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</invoke>\n```\n\n");
prompt.push_str("You may use multiple tool calls in a single response. ");
prompt.push_str("After tool execution, results appear in <tool_result> tags. ");
prompt.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
prompt
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
}
// ── 2. Safety ───────────────────────────────────────────────
@ -674,6 +689,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
tracing::warn!("Provider warmup failed (non-fatal): {e}");
}
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let runtime: Arc<dyn runtime::RuntimeAdapter> =
Arc::from(runtime::create_runtime(&config.runtime)?);
let security = Arc::new(SecurityPolicy::from_config(
&config.autonomy,
&config.workspace_dir,
));
let model = config
.default_model
.clone()
@ -685,6 +709,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
config.api_key.as_deref(),
)?);
let composio_key = if config.composio.enabled {
config.composio.api_key.as_deref()
} else {
None
};
let tools_registry = Arc::new(tools::all_tools_with_runtime(
&security,
runtime,
Arc::clone(&mem),
composio_key,
&config.browser,
&config.http_request,
&config.agents,
config.api_key.as_deref(),
));
// Build system prompt from workspace identity files + skills
let workspace = config.workspace_dir.clone();
let skills = crate::skills::load_skills(&workspace);
@ -723,14 +763,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
));
}
if config.composio.enabled {
tool_descs.push((
"composio",
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
));
}
if !config.agents.is_empty() {
tool_descs.push((
"delegate",
"Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.",
));
}
let system_prompt = build_system_prompt(
let mut system_prompt = build_system_prompt(
&workspace,
&model,
&tool_descs,
&skills,
Some(&config.identity),
);
system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref()));
if !skills.is_empty() {
println!(
@ -875,6 +928,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
channels_by_name,
provider: Arc::clone(&provider),
memory: Arc::clone(&mem),
tools_registry: Arc::clone(&tools_registry),
observer,
system_prompt: Arc::new(system_prompt),
model: Arc::new(model.clone()),
temperature,
@ -895,7 +950,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
mod tests {
use super::*;
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::providers::Provider;
use crate::observability::NoopObserver;
use crate::providers::{ChatMessage, Provider};
use crate::tools::{Tool, ToolResult};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
@ -967,6 +1024,131 @@ mod tests {
}
}
struct ToolCallingProvider;
fn tool_call_payload() -> String {
serde_json::json!({
"content": "",
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {
"name": "mock_price",
"arguments": "{\"symbol\":\"BTC\"}"
}
}]
})
.to_string()
}
#[async_trait::async_trait]
impl Provider for ToolCallingProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok(tool_call_payload())
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
let has_tool_results = messages
.iter()
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
if has_tool_results {
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
} else {
Ok(tool_call_payload())
}
}
}
struct MockPriceTool;
#[async_trait::async_trait]
impl Tool for MockPriceTool {
fn name(&self) -> &str {
"mock_price"
}
fn description(&self) -> &str {
"Return a mocked BTC price"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"symbol": { "type": "string" }
},
"required": ["symbol"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let symbol = args.get("symbol").and_then(serde_json::Value::as_str);
if symbol != Some("BTC") {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("unexpected symbol".to_string()),
});
}
Ok(ToolResult {
success: true,
output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(),
error: None,
})
}
}
#[tokio::test]
async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() {
let channel_impl = Arc::new(RecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: Arc::new(ToolCallingProvider),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("test-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
});
process_channel_message(
runtime_ctx,
traits::ChannelMessage {
id: "msg-1".to_string(),
sender: "alice".to_string(),
content: "What is the BTC price now?".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
},
)
.await;
let sent_messages = channel_impl.sent_messages.lock().await;
assert_eq!(sent_messages.len(), 1);
assert!(sent_messages[0].contains("BTC is currently around"));
assert!(!sent_messages[0].contains("\"tool_calls\""));
assert!(!sent_messages[0].contains("mock_price"));
}
struct NoopMemory;
#[async_trait::async_trait]
@ -1030,6 +1212,8 @@ mod tests {
delay: Duration::from_millis(250),
}),
memory: Arc::new(NoopMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("test-model".to_string()),
temperature: 0.0,
@ -1269,7 +1453,10 @@ mod tests {
// Reproduces the production crash path where channel logs truncate at 80 chars.
let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80));
assert!(result.is_ok(), "truncate_with_ellipsis should never panic on UTF-8");
assert!(
result.is_ok(),
"truncate_with_ellipsis should never panic on UTF-8"
);
let truncated = result.unwrap();
assert!(!truncated.is_empty());

View file

@ -1,7 +1,7 @@
pub mod schema;
pub use schema::{
AutonomyConfig, AuditConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config,
AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config,
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig,
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig,
ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,

View file

@ -168,7 +168,10 @@ impl HardwareConfig {
bail!("hardware.baud_rate must be greater than 0.");
}
if self.baud_rate > 4_000_000 {
bail!("hardware.baud_rate of {} exceeds the 4 MHz safety limit.", self.baud_rate);
bail!(
"hardware.baud_rate of {} exceeds the 4 MHz safety limit.",
self.baud_rate
);
}
// PWM frequency sanity
@ -228,20 +231,16 @@ fn discover_native_gpio(devices: &mut Vec<DiscoveredDevice>) {
if gpiomem.exists() || gpiochip.exists() {
// Try to read model from device tree
let model = read_board_model();
let name = model
.as_deref()
.unwrap_or("Linux SBC with GPIO");
let name = model.as_deref().unwrap_or("Linux SBC with GPIO");
devices.push(DiscoveredDevice {
name: format!("{name} (Native GPIO)"),
transport: HardwareTransport::Native,
device_path: Some(
if gpiomem.exists() {
device_path: Some(if gpiomem.exists() {
"/dev/gpiomem".into()
} else {
"/dev/gpiochip0".into()
},
),
}),
detail: model,
});
}
@ -287,10 +286,7 @@ fn serial_device_paths() -> Vec<String> {
"/dev/tty.wchusbserial*".into(), // CH340 clones
]
} else if cfg!(target_os = "linux") {
vec![
"/dev/ttyUSB*".into(),
"/dev/ttyACM*".into(),
]
vec!["/dev/ttyUSB*".into(), "/dev/ttyACM*".into()]
} else {
// Windows / other — not yet supported for auto-discovery
vec![]
@ -452,10 +448,7 @@ pub fn create_hal(config: &HardwareConfig) -> Result<Box<dyn HardwareHal>> {
);
}
HardwareTransport::Probe => {
let target = config
.probe_target
.as_deref()
.unwrap_or("unknown");
let target = config.probe_target.as_deref().unwrap_or("unknown");
bail!(
"Probe transport targeting '{}' is configured but the probe-rs HAL \
backend is not yet compiled in. This will be available in a future release.",
@ -471,15 +464,24 @@ pub fn create_hal(config: &HardwareConfig) -> Result<Box<dyn HardwareHal>> {
/// based on discovery results.
pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize {
// If we found native GPIO → recommend Native (index 0)
if devices.iter().any(|d| d.transport == HardwareTransport::Native) {
if devices
.iter()
.any(|d| d.transport == HardwareTransport::Native)
{
return 0;
}
// If we found serial devices → recommend Tethered (index 1)
if devices.iter().any(|d| d.transport == HardwareTransport::Serial) {
if devices
.iter()
.any(|d| d.transport == HardwareTransport::Serial)
{
return 1;
}
// If we found debug probes → recommend Probe (index 2)
if devices.iter().any(|d| d.transport == HardwareTransport::Probe) {
if devices
.iter()
.any(|d| d.transport == HardwareTransport::Probe)
{
return 2;
}
// Default: Software Only (index 3)
@ -487,10 +489,7 @@ pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize {
}
/// Build a `HardwareConfig` from a wizard selection and discovered devices.
pub fn config_from_wizard_choice(
choice: usize,
devices: &[DiscoveredDevice],
) -> HardwareConfig {
pub fn config_from_wizard_choice(choice: usize, devices: &[DiscoveredDevice]) -> HardwareConfig {
match choice {
// Native
0 => {
@ -548,39 +547,102 @@ mod tests {
#[test]
fn transport_parse_native_variants() {
assert_eq!(HardwareTransport::from_str_loose("native"), HardwareTransport::Native);
assert_eq!(HardwareTransport::from_str_loose("gpio"), HardwareTransport::Native);
assert_eq!(HardwareTransport::from_str_loose("rppal"), HardwareTransport::Native);
assert_eq!(HardwareTransport::from_str_loose("sysfs"), HardwareTransport::Native);
assert_eq!(HardwareTransport::from_str_loose("NATIVE"), HardwareTransport::Native);
assert_eq!(HardwareTransport::from_str_loose(" Native "), HardwareTransport::Native);
assert_eq!(
HardwareTransport::from_str_loose("native"),
HardwareTransport::Native
);
assert_eq!(
HardwareTransport::from_str_loose("gpio"),
HardwareTransport::Native
);
assert_eq!(
HardwareTransport::from_str_loose("rppal"),
HardwareTransport::Native
);
assert_eq!(
HardwareTransport::from_str_loose("sysfs"),
HardwareTransport::Native
);
assert_eq!(
HardwareTransport::from_str_loose("NATIVE"),
HardwareTransport::Native
);
assert_eq!(
HardwareTransport::from_str_loose(" Native "),
HardwareTransport::Native
);
}
#[test]
fn transport_parse_serial_variants() {
assert_eq!(HardwareTransport::from_str_loose("serial"), HardwareTransport::Serial);
assert_eq!(HardwareTransport::from_str_loose("uart"), HardwareTransport::Serial);
assert_eq!(HardwareTransport::from_str_loose("usb"), HardwareTransport::Serial);
assert_eq!(HardwareTransport::from_str_loose("tethered"), HardwareTransport::Serial);
assert_eq!(HardwareTransport::from_str_loose("SERIAL"), HardwareTransport::Serial);
assert_eq!(
HardwareTransport::from_str_loose("serial"),
HardwareTransport::Serial
);
assert_eq!(
HardwareTransport::from_str_loose("uart"),
HardwareTransport::Serial
);
assert_eq!(
HardwareTransport::from_str_loose("usb"),
HardwareTransport::Serial
);
assert_eq!(
HardwareTransport::from_str_loose("tethered"),
HardwareTransport::Serial
);
assert_eq!(
HardwareTransport::from_str_loose("SERIAL"),
HardwareTransport::Serial
);
}
#[test]
fn transport_parse_probe_variants() {
assert_eq!(HardwareTransport::from_str_loose("probe"), HardwareTransport::Probe);
assert_eq!(HardwareTransport::from_str_loose("probe-rs"), HardwareTransport::Probe);
assert_eq!(HardwareTransport::from_str_loose("swd"), HardwareTransport::Probe);
assert_eq!(HardwareTransport::from_str_loose("jtag"), HardwareTransport::Probe);
assert_eq!(HardwareTransport::from_str_loose("jlink"), HardwareTransport::Probe);
assert_eq!(HardwareTransport::from_str_loose("j-link"), HardwareTransport::Probe);
assert_eq!(
HardwareTransport::from_str_loose("probe"),
HardwareTransport::Probe
);
assert_eq!(
HardwareTransport::from_str_loose("probe-rs"),
HardwareTransport::Probe
);
assert_eq!(
HardwareTransport::from_str_loose("swd"),
HardwareTransport::Probe
);
assert_eq!(
HardwareTransport::from_str_loose("jtag"),
HardwareTransport::Probe
);
assert_eq!(
HardwareTransport::from_str_loose("jlink"),
HardwareTransport::Probe
);
assert_eq!(
HardwareTransport::from_str_loose("j-link"),
HardwareTransport::Probe
);
}
#[test]
fn transport_parse_none_and_unknown() {
assert_eq!(HardwareTransport::from_str_loose("none"), HardwareTransport::None);
assert_eq!(HardwareTransport::from_str_loose(""), HardwareTransport::None);
assert_eq!(HardwareTransport::from_str_loose("foobar"), HardwareTransport::None);
assert_eq!(HardwareTransport::from_str_loose("bluetooth"), HardwareTransport::None);
assert_eq!(
HardwareTransport::from_str_loose("none"),
HardwareTransport::None
);
assert_eq!(
HardwareTransport::from_str_loose(""),
HardwareTransport::None
);
assert_eq!(
HardwareTransport::from_str_loose("foobar"),
HardwareTransport::None
);
assert_eq!(
HardwareTransport::from_str_loose("bluetooth"),
HardwareTransport::None
);
}
#[test]
@ -918,7 +980,9 @@ mod tests {
#[test]
fn noop_hal_firmware_upload_fails() {
let hal = NoopHal;
let err = hal.firmware_upload(Path::new("/tmp/firmware.bin")).unwrap_err();
let err = hal
.firmware_upload(Path::new("/tmp/firmware.bin"))
.unwrap_err();
assert!(err.to_string().contains("not enabled"));
assert!(err.to_string().contains("firmware.bin"));
}

View file

@ -1093,7 +1093,9 @@ fn setup_hardware() -> Result<HardwareConfig> {
}
// ── Probe: ask for target chip ──
if hw_config.transport_mode() == hardware::HardwareTransport::Probe && hw_config.probe_target.is_none() {
if hw_config.transport_mode() == hardware::HardwareTransport::Probe
&& hw_config.probe_target.is_none()
{
let target: String = Input::new()
.with_prompt(" Target MCU chip (e.g. STM32F411CEUx, nRF52840_xxAA)")
.default("STM32F411CEUx".into())
@ -2698,21 +2700,25 @@ fn print_summary(config: &Config) {
if config.hardware.enabled {
let mode = config.hardware.transport_mode();
match mode {
hardware::HardwareTransport::Native => style("Native GPIO (direct)").green().to_string(),
hardware::HardwareTransport::Native => {
style("Native GPIO (direct)").green().to_string()
}
hardware::HardwareTransport::Serial => format!(
"{}",
style(format!(
"Serial → {} @ {} baud",
config.hardware.serial_port.as_deref().unwrap_or("?"),
config.hardware.baud_rate
)).green()
))
.green()
),
hardware::HardwareTransport::Probe => format!(
"{}",
style(format!(
"Probe → {}",
config.hardware.probe_target.as_deref().unwrap_or("?")
)).green()
))
.green()
),
hardware::HardwareTransport::None => "disabled (software only)".to_string(),
}

View file

@ -88,7 +88,12 @@ impl AuditEvent {
}
/// Set the actor
pub fn with_actor(mut self, channel: String, user_id: Option<String>, username: Option<String>) -> Self {
pub fn with_actor(
mut self,
channel: String,
user_id: Option<String>,
username: Option<String>,
) -> Self {
self.actor = Some(Actor {
channel,
user_id,
@ -98,7 +103,13 @@ impl AuditEvent {
}
/// Set the action
pub fn with_action(mut self, command: String, risk_level: String, approved: bool, allowed: bool) -> Self {
pub fn with_action(
mut self,
command: String,
risk_level: String,
approved: bool,
allowed: bool,
) -> Self {
self.action = Some(Action {
command: Some(command),
risk_level: Some(risk_level),
@ -109,7 +120,13 @@ impl AuditEvent {
}
/// Set the result
pub fn with_result(mut self, success: bool, exit_code: Option<i32>, duration_ms: u64, error: Option<String>) -> Self {
pub fn with_result(
mut self,
success: bool,
exit_code: Option<i32>,
duration_ms: u64,
error: Option<String>,
) -> Self {
self.result = Some(ExecutionResult {
success,
exit_code,
@ -179,7 +196,12 @@ impl AuditLogger {
) -> Result<()> {
let event = AuditEvent::new(AuditEventType::CommandExecution)
.with_actor(channel.to_string(), None, None)
.with_action(command.to_string(), risk_level.to_string(), approved, allowed)
.with_action(
command.to_string(),
risk_level.to_string(),
approved,
allowed,
)
.with_result(success, None, duration_ms, None);
self.log(&event)
@ -224,8 +246,11 @@ mod tests {
#[test]
fn audit_event_with_actor() {
let event = AuditEvent::new(AuditEventType::CommandExecution)
.with_actor("telegram".to_string(), Some("123".to_string()), Some("@alice".to_string()));
let event = AuditEvent::new(AuditEventType::CommandExecution).with_actor(
"telegram".to_string(),
Some("123".to_string()),
Some("@alice".to_string()),
);
assert!(event.actor.is_some());
let actor = event.actor.as_ref().unwrap();
@ -236,8 +261,12 @@ mod tests {
#[test]
fn audit_event_with_action() {
let event = AuditEvent::new(AuditEventType::CommandExecution)
.with_action("ls -la".to_string(), "low".to_string(), false, true);
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
"ls -la".to_string(),
"low".to_string(),
false,
true,
);
assert!(event.action.is_some());
let action = event.action.as_ref().unwrap();

View file

@ -35,14 +35,23 @@ impl BubblewrapSandbox {
impl Sandbox for BubblewrapSandbox {
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
let program = cmd.get_program().to_string_lossy().to_string();
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
let args: Vec<String> = cmd
.get_args()
.map(|s| s.to_string_lossy().to_string())
.collect();
let mut bwrap_cmd = Command::new("bwrap");
bwrap_cmd.args([
"--ro-bind", "/usr", "/usr",
"--dev", "/dev",
"--proc", "/proc",
"--bind", "/tmp", "/tmp",
"--ro-bind",
"/usr",
"/usr",
"--dev",
"/dev",
"--proc",
"/proc",
"--bind",
"/tmp",
"/tmp",
"--unshare-all",
"--die-with-parent",
]);

View file

@ -25,7 +25,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
}
}
}
tracing::warn!("Landlock requested but not available, falling back to application-layer");
tracing::warn!(
"Landlock requested but not available, falling back to application-layer"
);
Arc::new(super::traits::NoopSandbox)
}
SandboxBackend::Firejail => {
@ -35,7 +37,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
return Arc::new(sandbox);
}
}
tracing::warn!("Firejail requested but not available, falling back to application-layer");
tracing::warn!(
"Firejail requested but not available, falling back to application-layer"
);
Arc::new(super::traits::NoopSandbox)
}
SandboxBackend::Bubblewrap => {
@ -48,7 +52,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
}
}
}
tracing::warn!("Bubblewrap requested but not available, falling back to application-layer");
tracing::warn!(
"Bubblewrap requested but not available, falling back to application-layer"
);
Arc::new(super::traits::NoopSandbox)
}
SandboxBackend::Docker => {

View file

@ -56,14 +56,21 @@ impl DockerSandbox {
impl Sandbox for DockerSandbox {
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
let program = cmd.get_program().to_string_lossy().to_string();
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
let args: Vec<String> = cmd
.get_args()
.map(|s| s.to_string_lossy().to_string())
.collect();
let mut docker_cmd = Command::new("docker");
docker_cmd.args([
"run", "--rm",
"--memory", "512m",
"--cpus", "1.0",
"--network", "none",
"run",
"--rm",
"--memory",
"512m",
"--cpus",
"1.0",
"--network",
"none",
]);
docker_cmd.arg(&self.image);
docker_cmd.arg(&program);

View file

@ -41,7 +41,10 @@ impl Sandbox for FirejailSandbox {
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
// Prepend firejail to the command
let program = cmd.get_program().to_string_lossy().to_string();
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
let args: Vec<String> = cmd
.get_args()
.map(|s| s.to_string_lossy().to_string())
.collect();
// Build firejail wrapper with security flags
let mut firejail_cmd = Command::new("firejail");
@ -100,7 +103,10 @@ mod tests {
let result = FirejailSandbox::new();
match result {
Ok(_) => println!("Firejail is installed"),
Err(e) => assert!(e.kind() == std::io::ErrorKind::NotFound || e.kind() == std::io::ErrorKind::Unsupported),
Err(e) => assert!(
e.kind() == std::io::ErrorKind::NotFound
|| e.kind() == std::io::ErrorKind::Unsupported
),
}
}

View file

@ -26,8 +26,7 @@ impl LandlockSandbox {
/// Create a Landlock sandbox with a specific workspace directory
pub fn with_workspace(workspace_dir: Option<std::path::PathBuf>) -> std::io::Result<Self> {
// Test if Landlock is available by trying to create a minimal ruleset
let test_ruleset = Ruleset::new()
.set_access_fs(AccessFS::read_file | AccessFS::write_file);
let test_ruleset = Ruleset::new().set_access_fs(AccessFS::read_file | AccessFS::write_file);
match test_ruleset.create() {
Ok(_) => Ok(Self { workspace_dir }),
@ -48,8 +47,7 @@ impl LandlockSandbox {
/// Apply Landlock restrictions to the current process
fn apply_restrictions(&self) -> std::io::Result<()> {
let mut ruleset = Ruleset::new()
.set_access_fs(
let mut ruleset = Ruleset::new().set_access_fs(
AccessFS::read_file
| AccessFS::write_file
| AccessFS::read_dir
@ -60,18 +58,24 @@ impl LandlockSandbox {
| AccessFS::make_fifo
| AccessFS::make_block
| AccessFS::make_reg
| AccessFS::make_sym
| AccessFS::make_sym,
);
// Allow workspace directory (read/write)
if let Some(ref workspace) = self.workspace_dir {
if workspace.exists() {
ruleset = ruleset.add_path(workspace, AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir)?;
ruleset = ruleset.add_path(
workspace,
AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir,
)?;
}
}
// Allow /tmp for general operations
ruleset = ruleset.add_path(Path::new("/tmp"), AccessFS::read_file | AccessFS::write_file)?;
ruleset = ruleset.add_path(
Path::new("/tmp"),
AccessFS::read_file | AccessFS::write_file,
)?;
// Allow /usr and /bin for executing commands
ruleset = ruleset.add_path(Path::new("/usr"), AccessFS::read_file | AccessFS::read_dir)?;
@ -193,7 +197,10 @@ mod tests {
// Result depends on platform and feature flag
match result {
Ok(sandbox) => assert!(sandbox.is_available()),
Err(_) => assert!(!cfg!(all(feature = "sandbox-landlock", target_os = "linux"))),
Err(_) => assert!(!cfg!(all(
feature = "sandbox-landlock",
target_os = "linux"
))),
}
}
}

View file

@ -1,7 +1,7 @@
pub mod audit;
pub mod detect;
#[cfg(feature = "sandbox-bubblewrap")]
pub mod bubblewrap;
pub mod detect;
pub mod docker;
#[cfg(target_os = "linux")]
pub mod firejail;

View file

@ -61,7 +61,10 @@ mod tests {
let mut cmd = Command::new("echo");
cmd.arg("test");
let original_program = cmd.get_program().to_string_lossy().to_string();
let original_args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
let original_args: Vec<String> = cmd
.get_args()
.map(|s| s.to_string_lossy().to_string())
.collect();
let sandbox = NoopSandbox;
assert!(sandbox.wrap_command(&mut cmd).is_ok());
@ -69,7 +72,9 @@ mod tests {
// Command should be unchanged
assert_eq!(cmd.get_program().to_string_lossy(), original_program);
assert_eq!(
cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect::<Vec<_>>(),
cmd.get_args()
.map(|s| s.to_string_lossy().to_string())
.collect::<Vec<_>>(),
original_args
);
}

View file

@ -124,7 +124,10 @@ impl HttpRequestTool {
fn truncate_response(&self, text: &str) -> String {
if text.len() > self.max_response_size {
let mut truncated = text.chars().take(self.max_response_size).collect::<String>();
let mut truncated = text
.chars()
.take(self.max_response_size)
.collect::<String>();
truncated.push_str("\n\n... [Response truncated due to size limit] ...");
truncated
} else {
@ -221,7 +224,10 @@ impl Tool for HttpRequestTool {
let sanitized_headers = self.sanitize_headers(&headers_val);
match self.execute_request(&url, method, sanitized_headers, body).await {
match self
.execute_request(&url, method, sanitized_headers, body)
.await
{
Ok(response) => {
let status = response.status();
let status_code = status.as_u16();
@ -407,7 +413,12 @@ mod tests {
autonomy: AutonomyLevel::Supervised,
..SecurityPolicy::default()
});
HttpRequestTool::new(security, allowed_domains.into_iter().map(String::from).collect(), 1_000_000, 30)
HttpRequestTool::new(
security,
allowed_domains.into_iter().map(String::from).collect(),
1_000_000,
30,
)
}
#[test]
@ -598,8 +609,14 @@ mod tests {
});
let sanitized = tool.sanitize_headers(&headers);
assert_eq!(sanitized.len(), 3);
assert!(sanitized.iter().any(|(k, v)| k == "Authorization" && v == "***REDACTED***"));
assert!(sanitized.iter().any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***"));
assert!(sanitized.iter().any(|(k, v)| k == "Content-Type" && v == "application/json"));
assert!(sanitized
.iter()
.any(|(k, v)| k == "Authorization" && v == "***REDACTED***"));
assert!(sanitized
.iter()
.any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***"));
assert!(sanitized
.iter()
.any(|(k, v)| k == "Content-Type" && v == "application/json"));
}
}

View file

@ -320,7 +320,15 @@ mod tests {
},
);
let tools = all_tools(&security, mem, None, &browser, &http, &agents, Some("sk-test"));
let tools = all_tools(
&security,
mem,
None,
&browser,
&http,
&agents,
Some("sk-test"),
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"delegate"));
}