use super::traits::{Tool, ToolResult}; use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; use std::sync::Arc; use std::time::Duration; /// Maximum shell command execution time before kill. const SHELL_TIMEOUT_SECS: u64 = 60; /// Maximum output size in bytes (1MB). const MAX_OUTPUT_BYTES: usize = 1_048_576; /// Shell command execution tool with sandboxing pub struct ShellTool { security: Arc, } impl ShellTool { pub fn new(security: Arc) -> Self { Self { security } } } #[async_trait] impl Tool for ShellTool { fn name(&self) -> &str { "shell" } fn description(&self) -> &str { "Execute a shell command in the workspace directory" } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "command": { "type": "string", "description": "The shell command to execute" } }, "required": ["command"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let command = args .get("command") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?; // Security check: validate command against allowlist if !self.security.is_command_allowed(command) { return Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Command not allowed by security policy: {command}")), }); } // Execute with timeout to prevent hanging commands let result = tokio::time::timeout( Duration::from_secs(SHELL_TIMEOUT_SECS), tokio::process::Command::new("sh") .arg("-c") .arg(command) .current_dir(&self.security.workspace_dir) .output(), ) .await; match result { Ok(Ok(output)) => { let mut stdout = String::from_utf8_lossy(&output.stdout).to_string(); let mut stderr = String::from_utf8_lossy(&output.stderr).to_string(); // Truncate output to prevent OOM if stdout.len() > MAX_OUTPUT_BYTES { stdout.truncate(MAX_OUTPUT_BYTES); stdout.push_str("\n... [output truncated at 1MB]"); } if stderr.len() > MAX_OUTPUT_BYTES { stderr.truncate(MAX_OUTPUT_BYTES); stderr.push_str("\n... [stderr truncated at 1MB]"); } Ok(ToolResult { success: output.status.success(), output: stdout, error: if stderr.is_empty() { None } else { Some(stderr) }, }) } Ok(Err(e)) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Failed to execute command: {e}")), }), Err(_) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!( "Command timed out after {SHELL_TIMEOUT_SECS}s and was killed" )), }), } } } #[cfg(test)] mod tests { use super::*; use crate::security::{AutonomyLevel, SecurityPolicy}; fn test_security(autonomy: AutonomyLevel) -> Arc { Arc::new(SecurityPolicy { autonomy, workspace_dir: std::env::temp_dir(), ..SecurityPolicy::default() }) } #[test] fn shell_tool_name() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); assert_eq!(tool.name(), "shell"); } #[test] fn shell_tool_description() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); assert!(!tool.description().is_empty()); } #[test] fn shell_tool_schema_has_command() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let schema = tool.parameters_schema(); assert!(schema["properties"]["command"].is_object()); assert!(schema["required"] .as_array() .unwrap() .contains(&json!("command"))); } #[tokio::test] async fn shell_executes_allowed_command() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let result = tool .execute(json!({"command": "echo hello"})) .await .unwrap(); assert!(result.success); assert!(result.output.trim().contains("hello")); assert!(result.error.is_none()); } #[tokio::test] async fn shell_blocks_disallowed_command() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("not allowed")); } #[tokio::test] async fn shell_blocks_readonly() { let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly)); let result = tool.execute(json!({"command": "ls"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("not allowed")); } #[tokio::test] async fn shell_missing_command_param() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let result = tool.execute(json!({})).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("command")); } #[tokio::test] async fn shell_wrong_type_param() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let result = tool.execute(json!({"command": 123})).await; assert!(result.is_err()); } #[tokio::test] async fn shell_captures_exit_code() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let result = tool .execute(json!({"command": "ls /nonexistent_dir_xyz"})) .await .unwrap(); assert!(!result.success); } }