fix(channels): execute tool calls in channel runtime (#302)
* fix(channels): execute tool calls in channel runtime (#302) * chore(fmt): align repo formatting with rustfmt 1.92
This commit is contained in:
parent
efabe9703f
commit
9d29f30a31
17 changed files with 483 additions and 127 deletions
|
|
@ -339,7 +339,7 @@ struct ParsedToolCall {
|
|||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
async fn agent_turn(
|
||||
pub(crate) async fn agent_turn(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
|
|
@ -414,7 +414,7 @@ async fn agent_turn(
|
|||
|
||||
/// Build the tool instruction block for the system prompt so the LLM knows
|
||||
/// how to invoke tools.
|
||||
fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
|
||||
pub(crate) fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
|
||||
let mut instructions = String::new();
|
||||
instructions.push_str("\n## Tool Use Protocol\n\n");
|
||||
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||
|
|
|
|||
|
|
@ -16,7 +16,12 @@ pub struct DiscordChannel {
|
|||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>, listen_to_bots: bool) -> Self {
|
||||
pub fn new(
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
|
|
|
|||
|
|
@ -20,10 +20,15 @@ pub use telegram::TelegramChannel;
|
|||
pub use traits::Channel;
|
||||
pub use whatsapp::WhatsAppChannel;
|
||||
|
||||
use crate::agent::loop_::{agent_turn, build_tool_instructions};
|
||||
use crate::config::Config;
|
||||
use crate::identity;
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::observability::{self, Observer};
|
||||
use crate::providers::{self, ChatMessage, Provider};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -46,6 +51,8 @@ struct ChannelRuntimeContext {
|
|||
channels_by_name: Arc<HashMap<String, Arc<dyn Channel>>>,
|
||||
provider: Arc<dyn Provider>,
|
||||
memory: Arc<dyn Memory>,
|
||||
tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||||
observer: Arc<dyn Observer>,
|
||||
system_prompt: Arc<String>,
|
||||
model: Arc<String>,
|
||||
temperature: f64,
|
||||
|
|
@ -166,11 +173,18 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
println!(" ⏳ Processing message...");
|
||||
let started_at = Instant::now();
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system(ctx.system_prompt.as_str()),
|
||||
ChatMessage::user(&enriched_message),
|
||||
];
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
ctx.provider.chat_with_system(
|
||||
Some(ctx.system_prompt.as_str()),
|
||||
&enriched_message,
|
||||
agent_turn(
|
||||
ctx.provider.as_ref(),
|
||||
&mut history,
|
||||
ctx.tools_registry.as_ref(),
|
||||
ctx.observer.as_ref(),
|
||||
ctx.model.as_str(),
|
||||
ctx.temperature,
|
||||
),
|
||||
|
|
@ -323,7 +337,8 @@ pub fn build_system_prompt(
|
|||
prompt.push_str("```\n<invoke>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</invoke>\n```\n\n");
|
||||
prompt.push_str("You may use multiple tool calls in a single response. ");
|
||||
prompt.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||
prompt.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
prompt
|
||||
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
}
|
||||
|
||||
// ── 2. Safety ───────────────────────────────────────────────
|
||||
|
|
@ -674,6 +689,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
tracing::warn!("Provider warmup failed (non-fatal): {e}");
|
||||
}
|
||||
|
||||
let observer: Arc<dyn Observer> =
|
||||
Arc::from(observability::create_observer(&config.observability));
|
||||
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||
let security = Arc::new(SecurityPolicy::from_config(
|
||||
&config.autonomy,
|
||||
&config.workspace_dir,
|
||||
));
|
||||
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
|
|
@ -685,6 +709,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
config.api_key.as_deref(),
|
||||
)?);
|
||||
|
||||
let composio_key = if config.composio.enabled {
|
||||
config.composio.api_key.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||
&security,
|
||||
runtime,
|
||||
Arc::clone(&mem),
|
||||
composio_key,
|
||||
&config.browser,
|
||||
&config.http_request,
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
));
|
||||
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let skills = crate::skills::load_skills(&workspace);
|
||||
|
|
@ -723,14 +763,27 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
"Open approved HTTPS URLs in Brave Browser (allowlist-only, no scraping)",
|
||||
));
|
||||
}
|
||||
if config.composio.enabled {
|
||||
tool_descs.push((
|
||||
"composio",
|
||||
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
|
||||
));
|
||||
}
|
||||
if !config.agents.is_empty() {
|
||||
tool_descs.push((
|
||||
"delegate",
|
||||
"Delegate a subtask to a specialized agent. Use when: a task benefits from a different model (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single prompt and returns its response.",
|
||||
));
|
||||
}
|
||||
|
||||
let system_prompt = build_system_prompt(
|
||||
let mut system_prompt = build_system_prompt(
|
||||
&workspace,
|
||||
&model,
|
||||
&tool_descs,
|
||||
&skills,
|
||||
Some(&config.identity),
|
||||
);
|
||||
system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref()));
|
||||
|
||||
if !skills.is_empty() {
|
||||
println!(
|
||||
|
|
@ -875,6 +928,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
channels_by_name,
|
||||
provider: Arc::clone(&provider),
|
||||
memory: Arc::clone(&mem),
|
||||
tools_registry: Arc::clone(&tools_registry),
|
||||
observer,
|
||||
system_prompt: Arc::new(system_prompt),
|
||||
model: Arc::new(model.clone()),
|
||||
temperature,
|
||||
|
|
@ -895,7 +950,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::providers::Provider;
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::{ChatMessage, Provider};
|
||||
use crate::tools::{Tool, ToolResult};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -967,6 +1024,131 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
struct ToolCallingProvider;
|
||||
|
||||
fn tool_call_payload() -> String {
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mock_price",
|
||||
"arguments": "{\"symbol\":\"BTC\"}"
|
||||
}
|
||||
}]
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for ToolCallingProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(tool_call_payload())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let has_tool_results = messages
|
||||
.iter()
|
||||
.any(|msg| msg.role == "user" && msg.content.contains("[Tool results]"));
|
||||
if has_tool_results {
|
||||
Ok("BTC is currently around $65,000 based on latest tool output.".to_string())
|
||||
} else {
|
||||
Ok(tool_call_payload())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MockPriceTool;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Tool for MockPriceTool {
|
||||
fn name(&self) -> &str {
|
||||
"mock_price"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Return a mocked BTC price"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": { "type": "string" }
|
||||
},
|
||||
"required": ["symbol"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let symbol = args.get("symbol").and_then(serde_json::Value::as_str);
|
||||
if symbol != Some("BTC") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("unexpected symbol".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: r#"{"symbol":"BTC","price_usd":65000}"#.to_string(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_executes_tool_calls_instead_of_sending_raw_json() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: Arc::new(ToolCallingProvider),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![Box::new(MockPriceTool)]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx,
|
||||
traits::ChannelMessage {
|
||||
id: "msg-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 1);
|
||||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||||
assert!(!sent_messages[0].contains("mock_price"));
|
||||
}
|
||||
|
||||
struct NoopMemory;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
|
@ -1030,6 +1212,8 @@ mod tests {
|
|||
delay: Duration::from_millis(250),
|
||||
}),
|
||||
memory: Arc::new(NoopMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
|
|
@ -1269,7 +1453,10 @@ mod tests {
|
|||
|
||||
// Reproduces the production crash path where channel logs truncate at 80 chars.
|
||||
let result = std::panic::catch_unwind(|| crate::util::truncate_with_ellipsis(msg, 80));
|
||||
assert!(result.is_ok(), "truncate_with_ellipsis should never panic on UTF-8");
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"truncate_with_ellipsis should never panic on UTF-8"
|
||||
);
|
||||
|
||||
let truncated = result.unwrap();
|
||||
assert!(!truncated.is_empty());
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod schema;
|
||||
|
||||
pub use schema::{
|
||||
AutonomyConfig, AuditConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config,
|
||||
AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig,
|
||||
ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
|
|
|
|||
|
|
@ -964,7 +964,7 @@ pub struct SandboxConfig {
|
|||
impl Default for SandboxConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: None, // Auto-detect
|
||||
enabled: None, // Auto-detect
|
||||
backend: SandboxBackend::Auto,
|
||||
firejail_args: Vec::new(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,7 +168,10 @@ impl HardwareConfig {
|
|||
bail!("hardware.baud_rate must be greater than 0.");
|
||||
}
|
||||
if self.baud_rate > 4_000_000 {
|
||||
bail!("hardware.baud_rate of {} exceeds the 4 MHz safety limit.", self.baud_rate);
|
||||
bail!(
|
||||
"hardware.baud_rate of {} exceeds the 4 MHz safety limit.",
|
||||
self.baud_rate
|
||||
);
|
||||
}
|
||||
|
||||
// PWM frequency sanity
|
||||
|
|
@ -228,20 +231,16 @@ fn discover_native_gpio(devices: &mut Vec<DiscoveredDevice>) {
|
|||
if gpiomem.exists() || gpiochip.exists() {
|
||||
// Try to read model from device tree
|
||||
let model = read_board_model();
|
||||
let name = model
|
||||
.as_deref()
|
||||
.unwrap_or("Linux SBC with GPIO");
|
||||
let name = model.as_deref().unwrap_or("Linux SBC with GPIO");
|
||||
|
||||
devices.push(DiscoveredDevice {
|
||||
name: format!("{name} (Native GPIO)"),
|
||||
transport: HardwareTransport::Native,
|
||||
device_path: Some(
|
||||
if gpiomem.exists() {
|
||||
"/dev/gpiomem".into()
|
||||
} else {
|
||||
"/dev/gpiochip0".into()
|
||||
},
|
||||
),
|
||||
device_path: Some(if gpiomem.exists() {
|
||||
"/dev/gpiomem".into()
|
||||
} else {
|
||||
"/dev/gpiochip0".into()
|
||||
}),
|
||||
detail: model,
|
||||
});
|
||||
}
|
||||
|
|
@ -287,10 +286,7 @@ fn serial_device_paths() -> Vec<String> {
|
|||
"/dev/tty.wchusbserial*".into(), // CH340 clones
|
||||
]
|
||||
} else if cfg!(target_os = "linux") {
|
||||
vec![
|
||||
"/dev/ttyUSB*".into(),
|
||||
"/dev/ttyACM*".into(),
|
||||
]
|
||||
vec!["/dev/ttyUSB*".into(), "/dev/ttyACM*".into()]
|
||||
} else {
|
||||
// Windows / other — not yet supported for auto-discovery
|
||||
vec![]
|
||||
|
|
@ -452,10 +448,7 @@ pub fn create_hal(config: &HardwareConfig) -> Result<Box<dyn HardwareHal>> {
|
|||
);
|
||||
}
|
||||
HardwareTransport::Probe => {
|
||||
let target = config
|
||||
.probe_target
|
||||
.as_deref()
|
||||
.unwrap_or("unknown");
|
||||
let target = config.probe_target.as_deref().unwrap_or("unknown");
|
||||
bail!(
|
||||
"Probe transport targeting '{}' is configured but the probe-rs HAL \
|
||||
backend is not yet compiled in. This will be available in a future release.",
|
||||
|
|
@ -471,15 +464,24 @@ pub fn create_hal(config: &HardwareConfig) -> Result<Box<dyn HardwareHal>> {
|
|||
/// based on discovery results.
|
||||
pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize {
|
||||
// If we found native GPIO → recommend Native (index 0)
|
||||
if devices.iter().any(|d| d.transport == HardwareTransport::Native) {
|
||||
if devices
|
||||
.iter()
|
||||
.any(|d| d.transport == HardwareTransport::Native)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
// If we found serial devices → recommend Tethered (index 1)
|
||||
if devices.iter().any(|d| d.transport == HardwareTransport::Serial) {
|
||||
if devices
|
||||
.iter()
|
||||
.any(|d| d.transport == HardwareTransport::Serial)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
// If we found debug probes → recommend Probe (index 2)
|
||||
if devices.iter().any(|d| d.transport == HardwareTransport::Probe) {
|
||||
if devices
|
||||
.iter()
|
||||
.any(|d| d.transport == HardwareTransport::Probe)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
// Default: Software Only (index 3)
|
||||
|
|
@ -487,10 +489,7 @@ pub fn recommended_wizard_default(devices: &[DiscoveredDevice]) -> usize {
|
|||
}
|
||||
|
||||
/// Build a `HardwareConfig` from a wizard selection and discovered devices.
|
||||
pub fn config_from_wizard_choice(
|
||||
choice: usize,
|
||||
devices: &[DiscoveredDevice],
|
||||
) -> HardwareConfig {
|
||||
pub fn config_from_wizard_choice(choice: usize, devices: &[DiscoveredDevice]) -> HardwareConfig {
|
||||
match choice {
|
||||
// Native
|
||||
0 => {
|
||||
|
|
@ -548,39 +547,102 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn transport_parse_native_variants() {
|
||||
assert_eq!(HardwareTransport::from_str_loose("native"), HardwareTransport::Native);
|
||||
assert_eq!(HardwareTransport::from_str_loose("gpio"), HardwareTransport::Native);
|
||||
assert_eq!(HardwareTransport::from_str_loose("rppal"), HardwareTransport::Native);
|
||||
assert_eq!(HardwareTransport::from_str_loose("sysfs"), HardwareTransport::Native);
|
||||
assert_eq!(HardwareTransport::from_str_loose("NATIVE"), HardwareTransport::Native);
|
||||
assert_eq!(HardwareTransport::from_str_loose(" Native "), HardwareTransport::Native);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("native"),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("gpio"),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("rppal"),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("sysfs"),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("NATIVE"),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose(" Native "),
|
||||
HardwareTransport::Native
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_parse_serial_variants() {
|
||||
assert_eq!(HardwareTransport::from_str_loose("serial"), HardwareTransport::Serial);
|
||||
assert_eq!(HardwareTransport::from_str_loose("uart"), HardwareTransport::Serial);
|
||||
assert_eq!(HardwareTransport::from_str_loose("usb"), HardwareTransport::Serial);
|
||||
assert_eq!(HardwareTransport::from_str_loose("tethered"), HardwareTransport::Serial);
|
||||
assert_eq!(HardwareTransport::from_str_loose("SERIAL"), HardwareTransport::Serial);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("serial"),
|
||||
HardwareTransport::Serial
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("uart"),
|
||||
HardwareTransport::Serial
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("usb"),
|
||||
HardwareTransport::Serial
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("tethered"),
|
||||
HardwareTransport::Serial
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("SERIAL"),
|
||||
HardwareTransport::Serial
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_parse_probe_variants() {
|
||||
assert_eq!(HardwareTransport::from_str_loose("probe"), HardwareTransport::Probe);
|
||||
assert_eq!(HardwareTransport::from_str_loose("probe-rs"), HardwareTransport::Probe);
|
||||
assert_eq!(HardwareTransport::from_str_loose("swd"), HardwareTransport::Probe);
|
||||
assert_eq!(HardwareTransport::from_str_loose("jtag"), HardwareTransport::Probe);
|
||||
assert_eq!(HardwareTransport::from_str_loose("jlink"), HardwareTransport::Probe);
|
||||
assert_eq!(HardwareTransport::from_str_loose("j-link"), HardwareTransport::Probe);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("probe"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("probe-rs"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("swd"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("jtag"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("jlink"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("j-link"),
|
||||
HardwareTransport::Probe
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transport_parse_none_and_unknown() {
|
||||
assert_eq!(HardwareTransport::from_str_loose("none"), HardwareTransport::None);
|
||||
assert_eq!(HardwareTransport::from_str_loose(""), HardwareTransport::None);
|
||||
assert_eq!(HardwareTransport::from_str_loose("foobar"), HardwareTransport::None);
|
||||
assert_eq!(HardwareTransport::from_str_loose("bluetooth"), HardwareTransport::None);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("none"),
|
||||
HardwareTransport::None
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose(""),
|
||||
HardwareTransport::None
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("foobar"),
|
||||
HardwareTransport::None
|
||||
);
|
||||
assert_eq!(
|
||||
HardwareTransport::from_str_loose("bluetooth"),
|
||||
HardwareTransport::None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -918,7 +980,9 @@ mod tests {
|
|||
#[test]
|
||||
fn noop_hal_firmware_upload_fails() {
|
||||
let hal = NoopHal;
|
||||
let err = hal.firmware_upload(Path::new("/tmp/firmware.bin")).unwrap_err();
|
||||
let err = hal
|
||||
.firmware_upload(Path::new("/tmp/firmware.bin"))
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("not enabled"));
|
||||
assert!(err.to_string().contains("firmware.bin"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1093,7 +1093,9 @@ fn setup_hardware() -> Result<HardwareConfig> {
|
|||
}
|
||||
|
||||
// ── Probe: ask for target chip ──
|
||||
if hw_config.transport_mode() == hardware::HardwareTransport::Probe && hw_config.probe_target.is_none() {
|
||||
if hw_config.transport_mode() == hardware::HardwareTransport::Probe
|
||||
&& hw_config.probe_target.is_none()
|
||||
{
|
||||
let target: String = Input::new()
|
||||
.with_prompt(" Target MCU chip (e.g. STM32F411CEUx, nRF52840_xxAA)")
|
||||
.default("STM32F411CEUx".into())
|
||||
|
|
@ -2698,21 +2700,25 @@ fn print_summary(config: &Config) {
|
|||
if config.hardware.enabled {
|
||||
let mode = config.hardware.transport_mode();
|
||||
match mode {
|
||||
hardware::HardwareTransport::Native => style("Native GPIO (direct)").green().to_string(),
|
||||
hardware::HardwareTransport::Native => {
|
||||
style("Native GPIO (direct)").green().to_string()
|
||||
}
|
||||
hardware::HardwareTransport::Serial => format!(
|
||||
"{}",
|
||||
style(format!(
|
||||
"Serial → {} @ {} baud",
|
||||
config.hardware.serial_port.as_deref().unwrap_or("?"),
|
||||
config.hardware.baud_rate
|
||||
)).green()
|
||||
))
|
||||
.green()
|
||||
),
|
||||
hardware::HardwareTransport::Probe => format!(
|
||||
"{}",
|
||||
style(format!(
|
||||
"Probe → {}",
|
||||
config.hardware.probe_target.as_deref().unwrap_or("?")
|
||||
)).green()
|
||||
))
|
||||
.green()
|
||||
),
|
||||
hardware::HardwareTransport::None => "disabled (software only)".to_string(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,7 +88,12 @@ impl AuditEvent {
|
|||
}
|
||||
|
||||
/// Set the actor
|
||||
pub fn with_actor(mut self, channel: String, user_id: Option<String>, username: Option<String>) -> Self {
|
||||
pub fn with_actor(
|
||||
mut self,
|
||||
channel: String,
|
||||
user_id: Option<String>,
|
||||
username: Option<String>,
|
||||
) -> Self {
|
||||
self.actor = Some(Actor {
|
||||
channel,
|
||||
user_id,
|
||||
|
|
@ -98,7 +103,13 @@ impl AuditEvent {
|
|||
}
|
||||
|
||||
/// Set the action
|
||||
pub fn with_action(mut self, command: String, risk_level: String, approved: bool, allowed: bool) -> Self {
|
||||
pub fn with_action(
|
||||
mut self,
|
||||
command: String,
|
||||
risk_level: String,
|
||||
approved: bool,
|
||||
allowed: bool,
|
||||
) -> Self {
|
||||
self.action = Some(Action {
|
||||
command: Some(command),
|
||||
risk_level: Some(risk_level),
|
||||
|
|
@ -109,7 +120,13 @@ impl AuditEvent {
|
|||
}
|
||||
|
||||
/// Set the result
|
||||
pub fn with_result(mut self, success: bool, exit_code: Option<i32>, duration_ms: u64, error: Option<String>) -> Self {
|
||||
pub fn with_result(
|
||||
mut self,
|
||||
success: bool,
|
||||
exit_code: Option<i32>,
|
||||
duration_ms: u64,
|
||||
error: Option<String>,
|
||||
) -> Self {
|
||||
self.result = Some(ExecutionResult {
|
||||
success,
|
||||
exit_code,
|
||||
|
|
@ -179,7 +196,12 @@ impl AuditLogger {
|
|||
) -> Result<()> {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution)
|
||||
.with_actor(channel.to_string(), None, None)
|
||||
.with_action(command.to_string(), risk_level.to_string(), approved, allowed)
|
||||
.with_action(
|
||||
command.to_string(),
|
||||
risk_level.to_string(),
|
||||
approved,
|
||||
allowed,
|
||||
)
|
||||
.with_result(success, None, duration_ms, None);
|
||||
|
||||
self.log(&event)
|
||||
|
|
@ -224,8 +246,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn audit_event_with_actor() {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution)
|
||||
.with_actor("telegram".to_string(), Some("123".to_string()), Some("@alice".to_string()));
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_actor(
|
||||
"telegram".to_string(),
|
||||
Some("123".to_string()),
|
||||
Some("@alice".to_string()),
|
||||
);
|
||||
|
||||
assert!(event.actor.is_some());
|
||||
let actor = event.actor.as_ref().unwrap();
|
||||
|
|
@ -236,8 +261,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn audit_event_with_action() {
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution)
|
||||
.with_action("ls -la".to_string(), "low".to_string(), false, true);
|
||||
let event = AuditEvent::new(AuditEventType::CommandExecution).with_action(
|
||||
"ls -la".to_string(),
|
||||
"low".to_string(),
|
||||
false,
|
||||
true,
|
||||
);
|
||||
|
||||
assert!(event.action.is_some());
|
||||
let action = event.action.as_ref().unwrap();
|
||||
|
|
|
|||
|
|
@ -35,14 +35,23 @@ impl BubblewrapSandbox {
|
|||
impl Sandbox for BubblewrapSandbox {
|
||||
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
|
||||
let program = cmd.get_program().to_string_lossy().to_string();
|
||||
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
|
||||
let args: Vec<String> = cmd
|
||||
.get_args()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.collect();
|
||||
|
||||
let mut bwrap_cmd = Command::new("bwrap");
|
||||
bwrap_cmd.args([
|
||||
"--ro-bind", "/usr", "/usr",
|
||||
"--dev", "/dev",
|
||||
"--proc", "/proc",
|
||||
"--bind", "/tmp", "/tmp",
|
||||
"--ro-bind",
|
||||
"/usr",
|
||||
"/usr",
|
||||
"--dev",
|
||||
"/dev",
|
||||
"--proc",
|
||||
"/proc",
|
||||
"--bind",
|
||||
"/tmp",
|
||||
"/tmp",
|
||||
"--unshare-all",
|
||||
"--die-with-parent",
|
||||
]);
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
|
|||
}
|
||||
}
|
||||
}
|
||||
tracing::warn!("Landlock requested but not available, falling back to application-layer");
|
||||
tracing::warn!(
|
||||
"Landlock requested but not available, falling back to application-layer"
|
||||
);
|
||||
Arc::new(super::traits::NoopSandbox)
|
||||
}
|
||||
SandboxBackend::Firejail => {
|
||||
|
|
@ -35,7 +37,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
|
|||
return Arc::new(sandbox);
|
||||
}
|
||||
}
|
||||
tracing::warn!("Firejail requested but not available, falling back to application-layer");
|
||||
tracing::warn!(
|
||||
"Firejail requested but not available, falling back to application-layer"
|
||||
);
|
||||
Arc::new(super::traits::NoopSandbox)
|
||||
}
|
||||
SandboxBackend::Bubblewrap => {
|
||||
|
|
@ -48,7 +52,9 @@ pub fn create_sandbox(config: &SecurityConfig) -> Arc<dyn Sandbox> {
|
|||
}
|
||||
}
|
||||
}
|
||||
tracing::warn!("Bubblewrap requested but not available, falling back to application-layer");
|
||||
tracing::warn!(
|
||||
"Bubblewrap requested but not available, falling back to application-layer"
|
||||
);
|
||||
Arc::new(super::traits::NoopSandbox)
|
||||
}
|
||||
SandboxBackend::Docker => {
|
||||
|
|
@ -138,7 +144,7 @@ mod tests {
|
|||
fn auto_mode_detects_something() {
|
||||
let config = SecurityConfig {
|
||||
sandbox: SandboxConfig {
|
||||
enabled: None, // Auto-detect
|
||||
enabled: None, // Auto-detect
|
||||
backend: SandboxBackend::Auto,
|
||||
firejail_args: Vec::new(),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -56,14 +56,21 @@ impl DockerSandbox {
|
|||
impl Sandbox for DockerSandbox {
|
||||
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
|
||||
let program = cmd.get_program().to_string_lossy().to_string();
|
||||
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
|
||||
let args: Vec<String> = cmd
|
||||
.get_args()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.collect();
|
||||
|
||||
let mut docker_cmd = Command::new("docker");
|
||||
docker_cmd.args([
|
||||
"run", "--rm",
|
||||
"--memory", "512m",
|
||||
"--cpus", "1.0",
|
||||
"--network", "none",
|
||||
"run",
|
||||
"--rm",
|
||||
"--memory",
|
||||
"512m",
|
||||
"--cpus",
|
||||
"1.0",
|
||||
"--network",
|
||||
"none",
|
||||
]);
|
||||
docker_cmd.arg(&self.image);
|
||||
docker_cmd.arg(&program);
|
||||
|
|
|
|||
|
|
@ -41,20 +41,23 @@ impl Sandbox for FirejailSandbox {
|
|||
fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()> {
|
||||
// Prepend firejail to the command
|
||||
let program = cmd.get_program().to_string_lossy().to_string();
|
||||
let args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
|
||||
let args: Vec<String> = cmd
|
||||
.get_args()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.collect();
|
||||
|
||||
// Build firejail wrapper with security flags
|
||||
let mut firejail_cmd = Command::new("firejail");
|
||||
firejail_cmd.args([
|
||||
"--private=home", // New home directory
|
||||
"--private-dev", // Minimal /dev
|
||||
"--nosound", // No audio
|
||||
"--no3d", // No 3D acceleration
|
||||
"--novideo", // No video devices
|
||||
"--nowheel", // No input devices
|
||||
"--notv", // No TV devices
|
||||
"--noprofile", // Skip profile loading
|
||||
"--quiet", // Suppress warnings
|
||||
"--private=home", // New home directory
|
||||
"--private-dev", // Minimal /dev
|
||||
"--nosound", // No audio
|
||||
"--no3d", // No 3D acceleration
|
||||
"--novideo", // No video devices
|
||||
"--nowheel", // No input devices
|
||||
"--notv", // No TV devices
|
||||
"--noprofile", // Skip profile loading
|
||||
"--quiet", // Suppress warnings
|
||||
]);
|
||||
|
||||
// Add the original command
|
||||
|
|
@ -100,7 +103,10 @@ mod tests {
|
|||
let result = FirejailSandbox::new();
|
||||
match result {
|
||||
Ok(_) => println!("Firejail is installed"),
|
||||
Err(e) => assert!(e.kind() == std::io::ErrorKind::NotFound || e.kind() == std::io::ErrorKind::Unsupported),
|
||||
Err(e) => assert!(
|
||||
e.kind() == std::io::ErrorKind::NotFound
|
||||
|| e.kind() == std::io::ErrorKind::Unsupported
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,8 +26,7 @@ impl LandlockSandbox {
|
|||
/// Create a Landlock sandbox with a specific workspace directory
|
||||
pub fn with_workspace(workspace_dir: Option<std::path::PathBuf>) -> std::io::Result<Self> {
|
||||
// Test if Landlock is available by trying to create a minimal ruleset
|
||||
let test_ruleset = Ruleset::new()
|
||||
.set_access_fs(AccessFS::read_file | AccessFS::write_file);
|
||||
let test_ruleset = Ruleset::new().set_access_fs(AccessFS::read_file | AccessFS::write_file);
|
||||
|
||||
match test_ruleset.create() {
|
||||
Ok(_) => Ok(Self { workspace_dir }),
|
||||
|
|
@ -48,30 +47,35 @@ impl LandlockSandbox {
|
|||
|
||||
/// Apply Landlock restrictions to the current process
|
||||
fn apply_restrictions(&self) -> std::io::Result<()> {
|
||||
let mut ruleset = Ruleset::new()
|
||||
.set_access_fs(
|
||||
AccessFS::read_file
|
||||
| AccessFS::write_file
|
||||
| AccessFS::read_dir
|
||||
| AccessFS::remove_dir
|
||||
| AccessFS::remove_file
|
||||
| AccessFS::make_char
|
||||
| AccessFS::make_sock
|
||||
| AccessFS::make_fifo
|
||||
| AccessFS::make_block
|
||||
| AccessFS::make_reg
|
||||
| AccessFS::make_sym
|
||||
);
|
||||
let mut ruleset = Ruleset::new().set_access_fs(
|
||||
AccessFS::read_file
|
||||
| AccessFS::write_file
|
||||
| AccessFS::read_dir
|
||||
| AccessFS::remove_dir
|
||||
| AccessFS::remove_file
|
||||
| AccessFS::make_char
|
||||
| AccessFS::make_sock
|
||||
| AccessFS::make_fifo
|
||||
| AccessFS::make_block
|
||||
| AccessFS::make_reg
|
||||
| AccessFS::make_sym,
|
||||
);
|
||||
|
||||
// Allow workspace directory (read/write)
|
||||
if let Some(ref workspace) = self.workspace_dir {
|
||||
if workspace.exists() {
|
||||
ruleset = ruleset.add_path(workspace, AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir)?;
|
||||
ruleset = ruleset.add_path(
|
||||
workspace,
|
||||
AccessFS::read_file | AccessFS::write_file | AccessFS::read_dir,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Allow /tmp for general operations
|
||||
ruleset = ruleset.add_path(Path::new("/tmp"), AccessFS::read_file | AccessFS::write_file)?;
|
||||
ruleset = ruleset.add_path(
|
||||
Path::new("/tmp"),
|
||||
AccessFS::read_file | AccessFS::write_file,
|
||||
)?;
|
||||
|
||||
// Allow /usr and /bin for executing commands
|
||||
ruleset = ruleset.add_path(Path::new("/usr"), AccessFS::read_file | AccessFS::read_dir)?;
|
||||
|
|
@ -193,7 +197,10 @@ mod tests {
|
|||
// Result depends on platform and feature flag
|
||||
match result {
|
||||
Ok(sandbox) => assert!(sandbox.is_available()),
|
||||
Err(_) => assert!(!cfg!(all(feature = "sandbox-landlock", target_os = "linux"))),
|
||||
Err(_) => assert!(!cfg!(all(
|
||||
feature = "sandbox-landlock",
|
||||
target_os = "linux"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod audit;
|
||||
pub mod detect;
|
||||
#[cfg(feature = "sandbox-bubblewrap")]
|
||||
pub mod bubblewrap;
|
||||
pub mod detect;
|
||||
pub mod docker;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod firejail;
|
||||
|
|
|
|||
|
|
@ -61,7 +61,10 @@ mod tests {
|
|||
let mut cmd = Command::new("echo");
|
||||
cmd.arg("test");
|
||||
let original_program = cmd.get_program().to_string_lossy().to_string();
|
||||
let original_args: Vec<String> = cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect();
|
||||
let original_args: Vec<String> = cmd
|
||||
.get_args()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.collect();
|
||||
|
||||
let sandbox = NoopSandbox;
|
||||
assert!(sandbox.wrap_command(&mut cmd).is_ok());
|
||||
|
|
@ -69,7 +72,9 @@ mod tests {
|
|||
// Command should be unchanged
|
||||
assert_eq!(cmd.get_program().to_string_lossy(), original_program);
|
||||
assert_eq!(
|
||||
cmd.get_args().map(|s| s.to_string_lossy().to_string()).collect::<Vec<_>>(),
|
||||
cmd.get_args()
|
||||
.map(|s| s.to_string_lossy().to_string())
|
||||
.collect::<Vec<_>>(),
|
||||
original_args
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,7 +124,10 @@ impl HttpRequestTool {
|
|||
|
||||
fn truncate_response(&self, text: &str) -> String {
|
||||
if text.len() > self.max_response_size {
|
||||
let mut truncated = text.chars().take(self.max_response_size).collect::<String>();
|
||||
let mut truncated = text
|
||||
.chars()
|
||||
.take(self.max_response_size)
|
||||
.collect::<String>();
|
||||
truncated.push_str("\n\n... [Response truncated due to size limit] ...");
|
||||
truncated
|
||||
} else {
|
||||
|
|
@ -221,7 +224,10 @@ impl Tool for HttpRequestTool {
|
|||
|
||||
let sanitized_headers = self.sanitize_headers(&headers_val);
|
||||
|
||||
match self.execute_request(&url, method, sanitized_headers, body).await {
|
||||
match self
|
||||
.execute_request(&url, method, sanitized_headers, body)
|
||||
.await
|
||||
{
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
let status_code = status.as_u16();
|
||||
|
|
@ -407,7 +413,12 @@ mod tests {
|
|||
autonomy: AutonomyLevel::Supervised,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
HttpRequestTool::new(security, allowed_domains.into_iter().map(String::from).collect(), 1_000_000, 30)
|
||||
HttpRequestTool::new(
|
||||
security,
|
||||
allowed_domains.into_iter().map(String::from).collect(),
|
||||
1_000_000,
|
||||
30,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -598,8 +609,14 @@ mod tests {
|
|||
});
|
||||
let sanitized = tool.sanitize_headers(&headers);
|
||||
assert_eq!(sanitized.len(), 3);
|
||||
assert!(sanitized.iter().any(|(k, v)| k == "Authorization" && v == "***REDACTED***"));
|
||||
assert!(sanitized.iter().any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***"));
|
||||
assert!(sanitized.iter().any(|(k, v)| k == "Content-Type" && v == "application/json"));
|
||||
assert!(sanitized
|
||||
.iter()
|
||||
.any(|(k, v)| k == "Authorization" && v == "***REDACTED***"));
|
||||
assert!(sanitized
|
||||
.iter()
|
||||
.any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***"));
|
||||
assert!(sanitized
|
||||
.iter()
|
||||
.any(|(k, v)| k == "Content-Type" && v == "application/json"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -320,7 +320,15 @@ mod tests {
|
|||
},
|
||||
);
|
||||
|
||||
let tools = all_tools(&security, mem, None, &browser, &http, &agents, Some("sk-test"));
|
||||
let tools = all_tools(
|
||||
&security,
|
||||
mem,
|
||||
None,
|
||||
&browser,
|
||||
&http,
|
||||
&agents,
|
||||
Some("sk-test"),
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"delegate"));
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue