diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index dfce36a..d284088 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -339,7 +339,7 @@ struct ParsedToolCall { /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. -async fn agent_turn( +pub(crate) async fn agent_turn( provider: &dyn Provider, history: &mut Vec, tools_registry: &[Box], @@ -414,7 +414,7 @@ async fn agent_turn( /// Build the tool instruction block for the system prompt so the LLM knows /// how to invoke tools. -fn build_tool_instructions(tools_registry: &[Box]) -> String { +pub(crate) fn build_tool_instructions(tools_registry: &[Box]) -> String { let mut instructions = String::new(); instructions.push_str("\n## Tool Use Protocol\n\n"); instructions.push_str("To use a tool, wrap a JSON object in tags:\n\n"); diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 27d2582..c685e96 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -16,7 +16,12 @@ pub struct DiscordChannel { } impl DiscordChannel { - pub fn new(bot_token: String, guild_id: Option, allowed_users: Vec, listen_to_bots: bool) -> Self { + pub fn new( + bot_token: String, + guild_id: Option, + allowed_users: Vec, + listen_to_bots: bool, + ) -> Self { Self { bot_token, guild_id, diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 936a26b..e7e3671 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -20,10 +20,15 @@ pub use telegram::TelegramChannel; pub use traits::Channel; pub use whatsapp::WhatsAppChannel; +use crate::agent::loop_::{agent_turn, build_tool_instructions}; use crate::config::Config; use crate::identity; use crate::memory::{self, Memory}; -use crate::providers::{self, Provider}; +use crate::observability::{self, Observer}; +use crate::providers::{self, ChatMessage, Provider}; +use crate::runtime; +use crate::security::SecurityPolicy; +use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use std::collections::HashMap; @@ -46,6 +51,8 @@ struct ChannelRuntimeContext { channels_by_name: Arc>>, provider: Arc, memory: Arc, + tools_registry: Arc>>, + observer: Arc, system_prompt: Arc, model: Arc, temperature: f64, @@ -166,11 +173,18 @@ async fn process_channel_message(ctx: Arc, msg: traits::C println!(" ⏳ Processing message..."); let started_at = Instant::now(); + let mut history = vec![ + ChatMessage::system(ctx.system_prompt.as_str()), + ChatMessage::user(&enriched_message), + ]; + let llm_result = tokio::time::timeout( Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), - ctx.provider.chat_with_system( - Some(ctx.system_prompt.as_str()), - &enriched_message, + agent_turn( + ctx.provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + ctx.observer.as_ref(), ctx.model.as_str(), ctx.temperature, ), @@ -323,7 +337,8 @@ pub fn build_system_prompt( prompt.push_str("```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n"); prompt.push_str("You may use multiple tool calls in a single response. "); prompt.push_str("After tool execution, results appear in tags. "); - prompt.push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + prompt + .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); } // ── 2. Safety ─────────────────────────────────────────────── @@ -674,6 +689,15 @@ pub async fn start_channels(config: Config) -> Result<()> { tracing::warn!("Provider warmup failed (non-fatal): {e}"); } + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + let model = config .default_model .clone() @@ -685,6 +709,22 @@ pub async fn start_channels(config: Config) -> Result<()> { config.api_key.as_deref(), )?); + let composio_key = if config.composio.enabled { + config.composio.api_key.as_deref() + } else { + None + }; + let tools_registry = Arc::new(tools::all_tools_with_runtime( + &security, + runtime, + Arc::clone(&mem), + composio_key, + &config.browser, + &config.http_request, + &config.agents, + config.api_key.as_deref(), + )); + // Build system prompt from workspace identity files + skills let workspace = config.workspace_dir.clone(); let skills = crate::skills::load_skills(&workspace); @@ -723,14 +763,27 @@ pub async fn start_channels(config: Config) -> Result<()> { "Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)", )); } + if config.composio.enabled { + tool_descs.push(( + "composio", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", + )); + } + if !config.agents.is_empty() { + tool_descs.push(( + "delegate", + "Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.", + )); + } - let system_prompt = build_system_prompt( + let mut system_prompt = build_system_prompt( &workspace, &model, &tool_descs, &skills, Some(&config.identity), ); + system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref())); if !skills.is_empty() { println!( @@ -875,6 +928,8 @@ pub async fn start_channels(config: Config) -> Result<()> { channels_by_name, provider: Arc::clone(&provider), memory: Arc::clone(&mem), + tools_registry: Arc::clone(&tools_registry), + observer, system_prompt: Arc::new(system_prompt), model: Arc::new(model.clone()), temperature, @@ -895,7 +950,9 @@ pub async fn start_channels(config: Config) -> Result<()> { mod tests { use super::*; use crate::memory::{Memory, MemoryCategory, SqliteMemory}; - use crate::providers::Provider; + use crate::observability::NoopObserver; + use crate::providers::{ChatMessage, Provider}; + use crate::tools::{Tool, ToolResult}; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -967,6 +1024,131 @@ mod tests { } } + struct ToolCallingProvider; + + fn tool_call_payload() -> String { + serde_json::json!({ + "content": "", + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": { + "name": "mock_price", + "arguments": "{\"symbol\":\"BTC\"}" + } + }] + }) + .to_string() + } + + #[async_trait::async_trait] + impl Provider for ToolCallingProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok(tool_call_payload()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let has_tool_results = messages + .iter() + .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]")); + if has_tool_results { + Ok("BTC is currently around $65,000 based on latest tool output.".to_string()) + } else { + Ok(tool_call_payload()) + } + } + } + + struct MockPriceTool; + + #[async_trait::async_trait] + impl Tool for MockPriceTool { + fn name(&self) -> &str { + "mock_price" + } + + fn description(&self) -> &str { + "Return a mocked BTC price" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "symbol": { "type": "string" } + }, + "required": ["symbol"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let symbol = args.get("symbol").and_then(serde_json::Value::as_str); + if symbol != Some("BTC") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("unexpected symbol".to_string()), + }); + } + + Ok(ToolResult { + success: true, + output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(), + error: None, + }) + } + } + + #[tokio::test] + async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(ToolCallingProvider), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool)]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-1".to_string(), + sender: "alice".to_string(), + content: "What is the BTC price now?".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + }, + ) + .await; + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].contains("BTC is currently around")); + assert!(!sent_messages[0].contains("\"tool_calls\"")); + assert!(!sent_messages[0].contains("mock_price")); + } + struct NoopMemory; #[async_trait::async_trait] @@ -1030,6 +1212,8 @@ mod tests { delay: Duration::from_millis(250), }), memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), system_prompt: Arc::new("test-system-prompt".to_string()), model: Arc::new("test-model".to_string()), temperature: 0.0, @@ -1269,7 +1453,10 @@ mod tests { // Reproduces the production crash path where channel logs truncate at 80 chars. let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80)); - assert!(result.is_ok(), "truncate_with_ellipsis should never panic on UTF-8"); + assert!( + result.is_ok(), + "truncate_with_ellipsis should never panic on UTF-8" + ); let truncated = result.unwrap(); assert!(!truncated.is_empty()); diff --git a/src/config/mod.rs b/src/config/mod.rs index 376d83d..437befc 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,7 +1,7 @@ pub mod schema; pub use schema::{ - AutonomyConfig, AuditConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, + AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, diff --git a/src/config/schema.rs b/src/config/schema.rs index d25a816..e8d96a2 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -964,7 +964,7 @@ pub struct SandboxConfig { impl Default for SandboxConfig { fn default() -> Self { Self { - enabled: None, // Auto-detect + enabled: None, // Auto-detect backend: SandboxBackend::Auto, firejail_args: Vec::new(), } diff --git a/src/hardware/mod.rs b/src/hardware/mod.rs index cd54854..30b551b 100644 --- a/src/hardware/mod.rs +++ b/src/hardware/mod.rs @@ -168,7 +168,10 @@ impl HardwareConfig { bail!("hardware.baud_rate must be greater than 0."); } if self.baud_rate > 4_000_000 { - bail!("hardware.baud_rate of {} exceeds the 4 MHz safety limit.", self.baud_rate); + bail!( + "hardware.baud_rate of {} exceeds the 4 MHz safety limit.", + self.baud_rate + ); } // PWM frequency sanity @@ -228,20 +231,16 @@ fn discover_native_gpio(devices: &mut Vec) { if gpiomem.exists() || gpiochip.exists() { // Try to read model from device tree let model = read_board_model(); - let name = model - .as_deref() - .unwrap_or("Linux SBC with GPIO"); + let name = model.as_deref().unwrap_or("Linux SBC with GPIO"); devices.push(DiscoveredDevice { name: format!("{name} (Native GPIO)"), transport: HardwareTransport::Native, - device_path: Some( - if gpiomem.exists() { - "/dev/gpiomem".into() - } else { - "/dev/gpiochip0".into() - }, - ), + device_path: Some(if gpiomem.exists() { + "/dev/gpiomem".into() + } else { + "/dev/gpiochip0".into() + }), detail: model, }); } @@ -287,10 +286,7 @@ fn serial_device_paths() -> Vec { "/dev/tty.wchusbserial*".into(), // CH340 clones ] } else if cfg!(target_os = "linux") { - vec![ - "/dev/ttyUSB*".into(), - "/dev/ttyACM*".into(), - ] + vec!["/dev/ttyUSB*".into(), "/dev/ttyACM*".into()] } else { // Windows / other — not yet supported for auto-discovery vec![] @@ -452,10 +448,7 @@ pub fn create_hal(config: &HardwareConfig) -> Result> { ); } HardwareTransport::Probe => { - let target = config - .probe_target - .as_deref() - .unwrap_or("unknown"); + let target = config.probe_target.as_deref().unwrap_or("unknown"); bail!( "Probe transport targeting '{}' is configured but the probe-rs HAL \ backend is not yet compiled in. This will be available in a future release.", @@ -471,15 +464,24 @@ pub fn create_hal(config: &HardwareConfig) -> Result> { /// based on discovery results. pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize { // If we found native GPIO → recommend Native (index 0) - if devices.iter().any(|d| d.transport == HardwareTransport::Native) { + if devices + .iter() + .any(|d| d.transport == HardwareTransport::Native) + { return 0; } // If we found serial devices → recommend Tethered (index 1) - if devices.iter().any(|d| d.transport == HardwareTransport::Serial) { + if devices + .iter() + .any(|d| d.transport == HardwareTransport::Serial) + { return 1; } // If we found debug probes → recommend Probe (index 2) - if devices.iter().any(|d| d.transport == HardwareTransport::Probe) { + if devices + .iter() + .any(|d| d.transport == HardwareTransport::Probe) + { return 2; } // Default: Software Only (index 3) @@ -487,10 +489,7 @@ pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize { } /// Build a `HardwareConfig` from a wizard selection and discovered devices. -pub fn config_from_wizard_choice( - choice: usize, - devices: &[DiscoveredDevice], -) -> HardwareConfig { +pub fn config_from_wizard_choice(choice: usize, devices: &[DiscoveredDevice]) -> HardwareConfig { match choice { // Native 0 => { @@ -548,39 +547,102 @@ mod tests { #[test] fn transport_parse_native_variants() { - assert_eq!(HardwareTransport::from_str_loose("native"), HardwareTransport::Native); - assert_eq!(HardwareTransport::from_str_loose("gpio"), HardwareTransport::Native); - assert_eq!(HardwareTransport::from_str_loose("rppal"), HardwareTransport::Native); - assert_eq!(HardwareTransport::from_str_loose("sysfs"), HardwareTransport::Native); - assert_eq!(HardwareTransport::from_str_loose("NATIVE"), HardwareTransport::Native); - assert_eq!(HardwareTransport::from_str_loose(" Native "), HardwareTransport::Native); + assert_eq!( + HardwareTransport::from_str_loose("native"), + HardwareTransport::Native + ); + assert_eq!( + HardwareTransport::from_str_loose("gpio"), + HardwareTransport::Native + ); + assert_eq!( + HardwareTransport::from_str_loose("rppal"), + HardwareTransport::Native + ); + assert_eq!( + HardwareTransport::from_str_loose("sysfs"), + HardwareTransport::Native + ); + assert_eq!( + HardwareTransport::from_str_loose("NATIVE"), + HardwareTransport::Native + ); + assert_eq!( + HardwareTransport::from_str_loose(" Native "), + HardwareTransport::Native + ); } #[test] fn transport_parse_serial_variants() { - assert_eq!(HardwareTransport::from_str_loose("serial"), HardwareTransport::Serial); - assert_eq!(HardwareTransport::from_str_loose("uart"), HardwareTransport::Serial); - assert_eq!(HardwareTransport::from_str_loose("usb"), HardwareTransport::Serial); - assert_eq!(HardwareTransport::from_str_loose("tethered"), HardwareTransport::Serial); - assert_eq!(HardwareTransport::from_str_loose("SERIAL"), HardwareTransport::Serial); + assert_eq!( + HardwareTransport::from_str_loose("serial"), + HardwareTransport::Serial + ); + assert_eq!( + HardwareTransport::from_str_loose("uart"), + HardwareTransport::Serial + ); + assert_eq!( + HardwareTransport::from_str_loose("usb"), + HardwareTransport::Serial + ); + assert_eq!( + HardwareTransport::from_str_loose("tethered"), + HardwareTransport::Serial + ); + assert_eq!( + HardwareTransport::from_str_loose("SERIAL"), + HardwareTransport::Serial + ); } #[test] fn transport_parse_probe_variants() { - assert_eq!(HardwareTransport::from_str_loose("probe"), HardwareTransport::Probe); - assert_eq!(HardwareTransport::from_str_loose("probe-rs"), HardwareTransport::Probe); - assert_eq!(HardwareTransport::from_str_loose("swd"), HardwareTransport::Probe); - assert_eq!(HardwareTransport::from_str_loose("jtag"), HardwareTransport::Probe); - assert_eq!(HardwareTransport::from_str_loose("jlink"), HardwareTransport::Probe); - assert_eq!(HardwareTransport::from_str_loose("j-link"), HardwareTransport::Probe); + assert_eq!( + HardwareTransport::from_str_loose("probe"), + HardwareTransport::Probe + ); + assert_eq!( + HardwareTransport::from_str_loose("probe-rs"), + HardwareTransport::Probe + ); + assert_eq!( + HardwareTransport::from_str_loose("swd"), + HardwareTransport::Probe + ); + assert_eq!( + HardwareTransport::from_str_loose("jtag"), + HardwareTransport::Probe + ); + assert_eq!( + HardwareTransport::from_str_loose("jlink"), + HardwareTransport::Probe + ); + assert_eq!( + HardwareTransport::from_str_loose("j-link"), + HardwareTransport::Probe + ); } #[test] fn transport_parse_none_and_unknown() { - assert_eq!(HardwareTransport::from_str_loose("none"), HardwareTransport::None); - assert_eq!(HardwareTransport::from_str_loose(""), HardwareTransport::None); - assert_eq!(HardwareTransport::from_str_loose("foobar"), HardwareTransport::None); - assert_eq!(HardwareTransport::from_str_loose("bluetooth"), HardwareTransport::None); + assert_eq!( + HardwareTransport::from_str_loose("none"), + HardwareTransport::None + ); + assert_eq!( + HardwareTransport::from_str_loose(""), + HardwareTransport::None + ); + assert_eq!( + HardwareTransport::from_str_loose("foobar"), + HardwareTransport::None + ); + assert_eq!( + HardwareTransport::from_str_loose("bluetooth"), + HardwareTransport::None + ); } #[test] @@ -918,7 +980,9 @@ mod tests { #[test] fn noop_hal_firmware_upload_fails() { let hal = NoopHal; - let err = hal.firmware_upload(Path::new("/tmp/firmware.bin")).unwrap_err(); + let err = hal + .firmware_upload(Path::new("/tmp/firmware.bin")) + .unwrap_err(); assert!(err.to_string().contains("not enabled")); assert!(err.to_string().contains("firmware.bin")); } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 69e0f83..c749d07 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -1093,7 +1093,9 @@ fn setup_hardware() -> Result { } // ── Probe: ask for target chip ── - if hw_config.transport_mode() == hardware::HardwareTransport::Probe && hw_config.probe_target.is_none() { + if hw_config.transport_mode() == hardware::HardwareTransport::Probe + && hw_config.probe_target.is_none() + { let target: String = Input::new() .with_prompt(" Target MCU chip (e.g. STM32F411CEUx, nRF52840_xxAA)") .default("STM32F411CEUx".into()) @@ -2698,21 +2700,25 @@ fn print_summary(config: &Config) { if config.hardware.enabled { let mode = config.hardware.transport_mode(); match mode { - hardware::HardwareTransport::Native => style("Native GPIO (direct)").green().to_string(), + hardware::HardwareTransport::Native => { + style("Native GPIO (direct)").green().to_string() + } hardware::HardwareTransport::Serial => format!( "{}", style(format!( "Serial → {} @ {} baud", config.hardware.serial_port.as_deref().unwrap_or("?"), config.hardware.baud_rate - )).green() + )) + .green() ), hardware::HardwareTransport::Probe => format!( "{}", style(format!( "Probe → {}", config.hardware.probe_target.as_deref().unwrap_or("?") - )).green() + )) + .green() ), hardware::HardwareTransport::None => "disabled (software only)".to_string(), } diff --git a/src/security/audit.rs b/src/security/audit.rs index 971134e..b7dabae 100644 --- a/src/security/audit.rs +++ b/src/security/audit.rs @@ -88,7 +88,12 @@ impl AuditEvent { } /// Set the actor - pub fn with_actor(mut self, channel: String, user_id: Option, username: Option) -> Self { + pub fn with_actor( + mut self, + channel: String, + user_id: Option, + username: Option, + ) -> Self { self.actor = Some(Actor { channel, user_id, @@ -98,7 +103,13 @@ impl AuditEvent { } /// Set the action - pub fn with_action(mut self, command: String, risk_level: String, approved: bool, allowed: bool) -> Self { + pub fn with_action( + mut self, + command: String, + risk_level: String, + approved: bool, + allowed: bool, + ) -> Self { self.action = Some(Action { command: Some(command), risk_level: Some(risk_level), @@ -109,7 +120,13 @@ impl AuditEvent { } /// Set the result - pub fn with_result(mut self, success: bool, exit_code: Option, duration_ms: u64, error: Option) -> Self { + pub fn with_result( + mut self, + success: bool, + exit_code: Option, + duration_ms: u64, + error: Option, + ) -> Self { self.result = Some(ExecutionResult { success, exit_code, @@ -179,7 +196,12 @@ impl AuditLogger { ) -> Result<()> { let event = AuditEvent::new(AuditEventType::CommandExecution) .with_actor(channel.to_string(), None, None) - .with_action(command.to_string(), risk_level.to_string(), approved, allowed) + .with_action( + command.to_string(), + risk_level.to_string(), + approved, + allowed, + ) .with_result(success, None, duration_ms, None); self.log(&event) @@ -224,8 +246,11 @@ mod tests { #[test] fn audit_event_with_actor() { - let event = AuditEvent::new(AuditEventType::CommandExecution) - .with_actor("telegram".to_string(), Some("123".to_string()), Some("@alice".to_string())); + let event = AuditEvent::new(AuditEventType::CommandExecution).with_actor( + "telegram".to_string(), + Some("123".to_string()), + Some("@alice".to_string()), + ); assert!(event.actor.is_some()); let actor = event.actor.as_ref().unwrap(); @@ -236,8 +261,12 @@ mod tests { #[test] fn audit_event_with_action() { - let event = AuditEvent::new(AuditEventType::CommandExecution) - .with_action("ls -la".to_string(), "low".to_string(), false, true); + let event = AuditEvent::new(AuditEventType::CommandExecution).with_action( + "ls -la".to_string(), + "low".to_string(), + false, + true, + ); assert!(event.action.is_some()); let action = event.action.as_ref().unwrap(); diff --git a/src/security/bubblewrap.rs b/src/security/bubblewrap.rs index 1c83c8f..5c7106e 100644 --- a/src/security/bubblewrap.rs +++ b/src/security/bubblewrap.rs @@ -35,14 +35,23 @@ impl BubblewrapSandbox { impl Sandbox for BubblewrapSandbox { fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { let program = cmd.get_program().to_string_lossy().to_string(); - let args: Vec = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); let mut bwrap_cmd = Command::new("bwrap"); bwrap_cmd.args([ - "--ro-bind", "/usr", "/usr", - "--dev", "/dev", - "--proc", "/proc", - "--bind", "/tmp", "/tmp", + "--ro-bind", + "/usr", + "/usr", + "--dev", + "/dev", + "--proc", + "/proc", + "--bind", + "/tmp", + "/tmp", "--unshare-all", "--die-with-parent", ]); diff --git a/src/security/detect.rs b/src/security/detect.rs index 11c7ea0..751d8d0 100644 --- a/src/security/detect.rs +++ b/src/security/detect.rs @@ -25,7 +25,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc { } } } - tracing::warn!("Landlock requested but not available, falling back to application-layer"); + tracing::warn!( + "Landlock requested but not available, falling back to application-layer" + ); Arc::new(super::traits::NoopSandbox) } SandboxBackend::Firejail => { @@ -35,7 +37,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc { return Arc::new(sandbox); } } - tracing::warn!("Firejail requested but not available, falling back to application-layer"); + tracing::warn!( + "Firejail requested but not available, falling back to application-layer" + ); Arc::new(super::traits::NoopSandbox) } SandboxBackend::Bubblewrap => { @@ -48,7 +52,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc { } } } - tracing::warn!("Bubblewrap requested but not available, falling back to application-layer"); + tracing::warn!( + "Bubblewrap requested but not available, falling back to application-layer" + ); Arc::new(super::traits::NoopSandbox) } SandboxBackend::Docker => { @@ -138,7 +144,7 @@ mod tests { fn auto_mode_detects_something() { let config = SecurityConfig { sandbox: SandboxConfig { - enabled: None, // Auto-detect + enabled: None, // Auto-detect backend: SandboxBackend::Auto, firejail_args: Vec::new(), }, diff --git a/src/security/docker.rs b/src/security/docker.rs index 84aac10..2c32e20 100644 --- a/src/security/docker.rs +++ b/src/security/docker.rs @@ -56,14 +56,21 @@ impl DockerSandbox { impl Sandbox for DockerSandbox { fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { let program = cmd.get_program().to_string_lossy().to_string(); - let args: Vec = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); let mut docker_cmd = Command::new("docker"); docker_cmd.args([ - "run", "--rm", - "--memory", "512m", - "--cpus", "1.0", - "--network", "none", + "run", + "--rm", + "--memory", + "512m", + "--cpus", + "1.0", + "--network", + "none", ]); docker_cmd.arg(&self.image); docker_cmd.arg(&program); diff --git a/src/security/firejail.rs b/src/security/firejail.rs index 08bbf3c..9eeb6c7 100644 --- a/src/security/firejail.rs +++ b/src/security/firejail.rs @@ -41,20 +41,23 @@ impl Sandbox for FirejailSandbox { fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> { // Prepend firejail to the command let program = cmd.get_program().to_string_lossy().to_string(); - let args: Vec = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect(); + let args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); // Build firejail wrapper with security flags let mut firejail_cmd = Command::new("firejail"); firejail_cmd.args([ - "--private=home", // New home directory - "--private-dev", // Minimal /dev - "--nosound", // No audio - "--no3d", // No 3D acceleration - "--novideo", // No video devices - "--nowheel", // No input devices - "--notv", // No TV devices - "--noprofile", // Skip profile loading - "--quiet", // Suppress warnings + "--private=home", // New home directory + "--private-dev", // Minimal /dev + "--nosound", // No audio + "--no3d", // No 3D acceleration + "--novideo", // No video devices + "--nowheel", // No input devices + "--notv", // No TV devices + "--noprofile", // Skip profile loading + "--quiet", // Suppress warnings ]); // Add the original command @@ -100,7 +103,10 @@ mod tests { let result = FirejailSandbox::new(); match result { Ok(_) => println!("Firejail is installed"), - Err(e) => assert!(e.kind() == std::io::ErrorKind::NotFound || e.kind() == std::io::ErrorKind::Unsupported), + Err(e) => assert!( + e.kind() == std::io::ErrorKind::NotFound + || e.kind() == std::io::ErrorKind::Unsupported + ), } } diff --git a/src/security/landlock.rs b/src/security/landlock.rs index 90942e2..afb990f 100644 --- a/src/security/landlock.rs +++ b/src/security/landlock.rs @@ -26,8 +26,7 @@ impl LandlockSandbox { /// Create a Landlock sandbox with a specific workspace directory pub fn with_workspace(workspace_dir: Option) -> std::io::Result { // Test if Landlock is available by trying to create a minimal ruleset - let test_ruleset = Ruleset::new() - .set_access_fs(AccessFS::read_file | AccessFS::write_file); + let test_ruleset = Ruleset::new().set_access_fs(AccessFS::read_file | AccessFS::write_file); match test_ruleset.create() { Ok(_) => Ok(Self { workspace_dir }), @@ -48,30 +47,35 @@ impl LandlockSandbox { /// Apply Landlock restrictions to the current process fn apply_restrictions(&self) -> std::io::Result<()> { - let mut ruleset = Ruleset::new() - .set_access_fs( - AccessFS::read_file - | AccessFS::write_file - | AccessFS::read_dir - | AccessFS::remove_dir - | AccessFS::remove_file - | AccessFS::make_char - | AccessFS::make_sock - | AccessFS::make_fifo - | AccessFS::make_block - | AccessFS::make_reg - | AccessFS::make_sym - ); + let mut ruleset = Ruleset::new().set_access_fs( + AccessFS::read_file + | AccessFS::write_file + | AccessFS::read_dir + | AccessFS::remove_dir + | AccessFS::remove_file + | AccessFS::make_char + | AccessFS::make_sock + | AccessFS::make_fifo + | AccessFS::make_block + | AccessFS::make_reg + | AccessFS::make_sym, + ); // Allow workspace directory (read/write) if let Some(ref workspace) = self.workspace_dir { if workspace.exists() { - ruleset = ruleset.add_path(workspace, AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir)?; + ruleset = ruleset.add_path( + workspace, + AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir, + )?; } } // Allow /tmp for general operations - ruleset = ruleset.add_path(Path::new("/tmp"), AccessFS::read_file | AccessFS::write_file)?; + ruleset = ruleset.add_path( + Path::new("/tmp"), + AccessFS::read_file | AccessFS::write_file, + )?; // Allow /usr and /bin for executing commands ruleset = ruleset.add_path(Path::new("/usr"), AccessFS::read_file | AccessFS::read_dir)?; @@ -193,7 +197,10 @@ mod tests { // Result depends on platform and feature flag match result { Ok(sandbox) => assert!(sandbox.is_available()), - Err(_) => assert!(!cfg!(all(feature = "sandbox-landlock", target_os = "linux"))), + Err(_) => assert!(!cfg!(all( + feature = "sandbox-landlock", + target_os = "linux" + ))), } } } diff --git a/src/security/mod.rs b/src/security/mod.rs index 60885bd..498fd18 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -1,7 +1,7 @@ pub mod audit; -pub mod detect; #[cfg(feature = "sandbox-bubblewrap")] pub mod bubblewrap; +pub mod detect; pub mod docker; #[cfg(target_os = "linux")] pub mod firejail; diff --git a/src/security/traits.rs b/src/security/traits.rs index 452480d..06fc4ef 100644 --- a/src/security/traits.rs +++ b/src/security/traits.rs @@ -61,7 +61,10 @@ mod tests { let mut cmd = Command::new("echo"); cmd.arg("test"); let original_program = cmd.get_program().to_string_lossy().to_string(); - let original_args: Vec = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect(); + let original_args: Vec = cmd + .get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect(); let sandbox = NoopSandbox; assert!(sandbox.wrap_command(&mut cmd).is_ok()); @@ -69,7 +72,9 @@ mod tests { // Command should be unchanged assert_eq!(cmd.get_program().to_string_lossy(), original_program); assert_eq!( - cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect::>(), + cmd.get_args() + .map(|s| s.to_string_lossy().to_string()) + .collect::>(), original_args ); } diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index 4ec9b01..36ebbd6 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -124,7 +124,10 @@ impl HttpRequestTool { fn truncate_response(&self, text: &str) -> String { if text.len() > self.max_response_size { - let mut truncated = text.chars().take(self.max_response_size).collect::(); + let mut truncated = text + .chars() + .take(self.max_response_size) + .collect::(); truncated.push_str("\n\n... [Response truncated due to size limit] ..."); truncated } else { @@ -221,7 +224,10 @@ impl Tool for HttpRequestTool { let sanitized_headers = self.sanitize_headers(&headers_val); - match self.execute_request(&url, method, sanitized_headers, body).await { + match self + .execute_request(&url, method, sanitized_headers, body) + .await + { Ok(response) => { let status = response.status(); let status_code = status.as_u16(); @@ -407,7 +413,12 @@ mod tests { autonomy: AutonomyLevel::Supervised, ..SecurityPolicy::default() }); - HttpRequestTool::new(security, allowed_domains.into_iter().map(String::from).collect(), 1_000_000, 30) + HttpRequestTool::new( + security, + allowed_domains.into_iter().map(String::from).collect(), + 1_000_000, + 30, + ) } #[test] @@ -598,8 +609,14 @@ mod tests { }); let sanitized = tool.sanitize_headers(&headers); assert_eq!(sanitized.len(), 3); - assert!(sanitized.iter().any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); - assert!(sanitized.iter().any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); - assert!(sanitized.iter().any(|(k, v)| k == "Content-Type" && v == "application/json")); + assert!(sanitized + .iter() + .any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); + assert!(sanitized + .iter() + .any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); + assert!(sanitized + .iter() + .any(|(k, v)| k == "Content-Type" && v == "application/json")); } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 0f139d1..a20a916 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -320,7 +320,15 @@ mod tests { }, ); - let tools = all_tools(&security, mem, None, &browser, &http, &agents, Some("sk-test")); + let tools = all_tools( + &security, + mem, + None, + &browser, + &http, + &agents, + Some("sk-test"), + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"delegate")); }