diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index caa7e53..a19d271 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -15,6 +15,45 @@ use std::sync::{Arc, LazyLock}; use std::time::Instant; use uuid::Uuid; +/// Events emitted during tool execution for real-time status display in channels. +#[derive(Debug, Clone)] +pub enum ToolStatusEvent { + /// LLM request started (thinking). + Thinking, + /// A tool is about to execute. + ToolStart { + name: String, + detail: Option, + }, +} + +/// Extract a short display summary from tool arguments for status display. +pub fn extract_tool_detail(tool_name: &str, args: &serde_json::Value) -> Option { + match tool_name { + "shell" => args.get("command").and_then(|v| v.as_str()).map(|s| { + if s.len() > 60 { + format!("{}...", &s[..57]) + } else { + s.to_string() + } + }), + "file_read" | "file_write" => args.get("path").and_then(|v| v.as_str()).map(String::from), + "memory_recall" | "web_search_tool" => args + .get("query") + .and_then(|v| v.as_str()) + .map(|s| format!("\"{s}\"")), + "http_request" | "browser_open" => { + args.get("url").and_then(|v| v.as_str()).map(String::from) + } + "git_operations" => args + .get("operation") + .and_then(|v| v.as_str()) + .map(String::from), + "memory_store" => args.get("key").and_then(|v| v.as_str()).map(String::from), + _ => None, + } +} + /// Minimum characters per chunk when relaying LLM text to a streaming draft. const STREAM_CHUNK_MIN_CHARS: usize = 80; @@ -841,6 +880,7 @@ pub(crate) async fn agent_turn( "channel", max_tool_iterations, None, + None, ) .await } @@ -861,6 +901,7 @@ pub(crate) async fn run_tool_call_loop( channel_name: &str, max_tool_iterations: usize, on_delta: Option>, + on_tool_status: Option>, ) -> Result { let max_iterations = if max_tool_iterations == 0 { DEFAULT_MAX_TOOL_ITERATIONS @@ -873,6 +914,10 @@ pub(crate) async fn run_tool_call_loop( let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty(); for _iteration in 0..max_iterations { + if let Some(ref tx) = on_tool_status { + let _ = tx.send(ToolStatusEvent::Thinking).await; + } + observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), model: model.to_string(), @@ -1026,6 +1071,15 @@ pub(crate) async fn run_tool_call_loop( observer.record_event(&ObserverEvent::ToolCallStart { tool: call.name.clone(), }); + if let Some(ref tx) = on_tool_status { + let detail = extract_tool_detail(&call.name, &call.arguments); + let _ = tx + .send(ToolStatusEvent::ToolStart { + name: call.name.clone(), + detail, + }) + .await; + } let start = Instant::now(); let result = if let Some(tool) = find_tool(tools_registry, &call.name) { match tool.execute(call.arguments.clone()).await { @@ -1398,6 +1452,7 @@ pub async fn run( "cli", config.agent.max_tool_iterations, None, + None, ) .await?; final_output = response.clone(); @@ -1524,6 +1579,7 @@ pub async fn run( "cli", config.agent.max_tool_iterations, None, + None, ) .await { @@ -2511,4 +2567,98 @@ browser_open/url>https://example.com"#; assert_eq!(calls[0].arguments["command"], "pwd"); assert_eq!(text, "Done"); } + + // ═══════════════════════════════════════════════════════════════════════ + // Tool Status Display - extract_tool_detail + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn extract_tool_detail_shell_short() { + let args = serde_json::json!({"command": "ls -la"}); + assert_eq!(extract_tool_detail("shell", &args), Some("ls -la".into())); + } + + #[test] + fn extract_tool_detail_shell_truncates_long_command() { + let long = "a".repeat(80); + let args = serde_json::json!({"command": long}); + let detail = extract_tool_detail("shell", &args).unwrap(); + assert_eq!(detail.len(), 60); // 57 chars + "..." + assert!(detail.ends_with("...")); + } + + #[test] + fn extract_tool_detail_file_read() { + let args = serde_json::json!({"path": "src/main.rs"}); + assert_eq!( + extract_tool_detail("file_read", &args), + Some("src/main.rs".into()) + ); + } + + #[test] + fn extract_tool_detail_file_write() { + let args = serde_json::json!({"path": "/tmp/out.txt", "content": "data"}); + assert_eq!( + extract_tool_detail("file_write", &args), + Some("/tmp/out.txt".into()) + ); + } + + #[test] + fn extract_tool_detail_memory_recall() { + let args = serde_json::json!({"query": "project goals"}); + assert_eq!( + extract_tool_detail("memory_recall", &args), + Some("\"project goals\"".into()) + ); + } + + #[test] + fn extract_tool_detail_web_search() { + let args = serde_json::json!({"query": "rust async"}); + assert_eq!( + extract_tool_detail("web_search_tool", &args), + Some("\"rust async\"".into()) + ); + } + + #[test] + fn extract_tool_detail_http_request() { + let args = serde_json::json!({"url": "https://example.com/api", "method": "GET"}); + assert_eq!( + extract_tool_detail("http_request", &args), + Some("https://example.com/api".into()) + ); + } + + #[test] + fn extract_tool_detail_git_operations() { + let args = serde_json::json!({"operation": "status"}); + assert_eq!( + extract_tool_detail("git_operations", &args), + Some("status".into()) + ); + } + + #[test] + fn extract_tool_detail_memory_store() { + let args = serde_json::json!({"key": "user_pref", "value": "dark mode"}); + assert_eq!( + extract_tool_detail("memory_store", &args), + Some("user_pref".into()) + ); + } + + #[test] + fn extract_tool_detail_unknown_tool_returns_none() { + let args = serde_json::json!({"foo": "bar"}); + assert_eq!(extract_tool_detail("unknown_tool", &args), None); + } + + #[test] + fn extract_tool_detail_missing_key_returns_none() { + let args = serde_json::json!({"other": "value"}); + assert_eq!(extract_tool_detail("shell", &args), None); + } } diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index 0b063c5..0216a21 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -726,7 +726,7 @@ mod tests { "!r:m".to_string(), vec![], Some(" ".to_string()), - Some("".to_string()), + Some(String::new()), ); assert!(ch.session_user_id_hint.is_none()); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 0fff1ec..6e8f2f8 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -30,7 +30,7 @@ pub use telegram::TelegramChannel; pub use traits::{Channel, SendMessage}; pub use whatsapp::WhatsAppChannel; -use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop}; +use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop, ToolStatusEvent}; use crate::config::Config; use crate::identity; use crate::memory::{self, Memory}; @@ -60,9 +60,6 @@ const BOOTSTRAP_MAX_CHARS: usize = 20_000; const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2; const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60; -/// Timeout for processing a single channel message (LLM + tools). -/// 300s for on-device LLMs (Ollama) which are slower than cloud APIs. -const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300; const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4; const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; @@ -85,6 +82,10 @@ enum ChannelRuntimeCommand { SetProvider(String), ShowModel, SetModel(String), + Clear, + ShowSystem, + ShowStatus, + Help, } #[derive(Debug, Clone, Default, Deserialize)] @@ -120,6 +121,7 @@ struct ChannelRuntimeContext { reliability: Arc, provider_runtime_options: providers::ProviderRuntimeOptions, workspace_dir: Arc, + channel_message_timeout_secs: u64, } fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { @@ -179,6 +181,10 @@ fn parse_runtime_command(channel_name: &str, content: &str) -> Option Some(ChannelRuntimeCommand::Clear), + "/system" => Some(ChannelRuntimeCommand::ShowSystem), + "/status" => Some(ChannelRuntimeCommand::ShowStatus), + "/help" => Some(ChannelRuntimeCommand::Help), _ => None, } } @@ -425,6 +431,52 @@ async fn handle_runtime_command_if_needed( ) } } + ChannelRuntimeCommand::Clear => { + clear_sender_history(ctx, &sender_key); + "Conversation history cleared.".to_string() + } + ChannelRuntimeCommand::ShowSystem => { + let prompt = ctx.system_prompt.as_str(); + if prompt.is_empty() { + "No system prompt configured.".to_string() + } else { + let truncated = truncate_with_ellipsis(prompt, 2000); + format!("```\n{truncated}\n```") + } + } + ChannelRuntimeCommand::ShowStatus => { + let tool_names: Vec<&str> = ctx.tools_registry.iter().map(|t| t.name()).collect(); + let mut status = String::new(); + let _ = writeln!(status, "Provider: `{}`", current.provider); + let _ = writeln!(status, "Model: `{}`", current.model); + let _ = writeln!(status, "Temperature: `{}`", ctx.temperature); + let _ = writeln!( + status, + "Tools: {} ({})", + tool_names.len(), + tool_names.join(", ") + ); + let _ = writeln!(status, "Memory: `{}`", ctx.memory.name()); + let _ = writeln!(status, "Max tool iterations: `{}`", ctx.max_tool_iterations); + let _ = writeln!( + status, + "Message timeout: `{}s`", + ctx.channel_message_timeout_secs + ); + status + } + ChannelRuntimeCommand::Help => { + let mut help = String::new(); + help.push_str("/help \u{2014} Show available commands\n"); + help.push_str("/model \u{2014} Show current model\n"); + help.push_str("/model \u{2014} Switch model\n"); + help.push_str("/models \u{2014} List providers\n"); + help.push_str("/models \u{2014} Switch provider\n"); + help.push_str("/clear \u{2014} Clear conversation history\n"); + help.push_str("/system \u{2014} Show system prompt\n"); + help.push_str("/status \u{2014} Show current configuration\n"); + help + } }; if let Err(err) = channel @@ -523,6 +575,15 @@ fn log_worker_join_result(result: Result<(), tokio::task::JoinError>) { } } +/// Compose tool status lines with the current draft content for display. +fn format_tool_display(tool_lines: &str, content: &str) -> String { + if tool_lines.is_empty() { + content.to_string() + } else { + format!("{tool_lines}{content}") + } +} + fn spawn_scoped_typing_task( channel: Arc, recipient: String, @@ -619,32 +680,38 @@ async fn process_channel_message(ctx: Arc, msg: traits::C .cloned() .unwrap_or_default(); - let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())]; + let system_prompt = match channel_delivery_instructions(&msg.channel) { + Some(instructions) => format!("{}\n\n{instructions}", ctx.system_prompt), + None => ctx.system_prompt.to_string(), + }; + let mut history = vec![ChatMessage::system(&system_prompt)]; history.append(&mut prior_turns); history.push(ChatMessage::user(&enriched_message)); - if let Some(instructions) = channel_delivery_instructions(&msg.channel) { - history.push(ChatMessage::system(instructions)); - } - // Determine if this channel supports streaming draft updates let use_streaming = target_channel .as_ref() .map_or(false, |ch| ch.supports_draft_updates()); - // Set up streaming channel if supported + // Set up streaming channels if supported let (delta_tx, delta_rx) = if use_streaming { let (tx, rx) = tokio::sync::mpsc::channel::(64); (Some(tx), Some(rx)) } else { (None, None) }; + let (tool_status_tx, tool_status_rx) = if use_streaming { + let (tx, rx) = tokio::sync::mpsc::channel::(32); + (Some(tx), Some(rx)) + } else { + (None, None) + }; // Send initial draft message if streaming let draft_message_id = if use_streaming { if let Some(channel) = target_channel.as_ref() { match channel - .send_draft(&SendMessage::new("...", &msg.reply_target)) + .send_draft(&SendMessage::new("Thinking...", &msg.reply_target)) .await { Ok(id) => id, @@ -660,24 +727,88 @@ async fn process_channel_message(ctx: Arc, msg: traits::C None }; - // Spawn a task to forward streaming deltas to draft updates - let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = ( - delta_rx, - draft_message_id.as_deref(), - target_channel.as_ref(), - ) { + // Spawn a task to merge tool status events and streaming deltas into draft updates + let draft_updater = if let (Some(draft_id_ref), Some(channel_ref)) = + (draft_message_id.as_deref(), target_channel.as_ref()) + { let channel = Arc::clone(channel_ref); let reply_target = msg.reply_target.clone(); let draft_id = draft_id_ref.to_string(); + let mut delta_rx = delta_rx; + let mut tool_status_rx = tool_status_rx; Some(tokio::spawn(async move { + let mut tool_lines = String::new(); let mut accumulated = String::new(); - while let Some(delta) = rx.recv().await { - accumulated.push_str(&delta); - if let Err(e) = channel - .update_draft(&reply_target, &draft_id, &accumulated) - .await - { - tracing::debug!("Draft update failed: {e}"); + + loop { + tokio::select! { + evt = async { + match tool_status_rx.as_mut() { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } + } => { + match evt { + Some(ToolStatusEvent::Thinking) => { + let display = format_tool_display(&tool_lines, "Thinking..."); + if let Err(e) = channel + .update_draft(&reply_target, &draft_id, &display) + .await + { + tracing::debug!("Draft update failed: {e}"); + } + } + Some(ToolStatusEvent::ToolStart { name, detail }) => { + let label = match detail { + Some(d) => format!("\u{1f527} {name}({d})\n"), + None => format!("\u{1f527} {name}\n"), + }; + tool_lines.push_str(&label); + let display = + format_tool_display(&tool_lines, "Thinking..."); + if let Err(e) = channel + .update_draft(&reply_target, &draft_id, &display) + .await + { + tracing::debug!("Draft update failed: {e}"); + } + } + None => { + // Tool status channel closed; keep consuming deltas + tool_status_rx = None; + if delta_rx.is_none() { + break; + } + } + } + } + delta = async { + match delta_rx.as_mut() { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } + } => { + match delta { + Some(text) => { + accumulated.push_str(&text); + let display = + format_tool_display(&tool_lines, &accumulated); + if let Err(e) = channel + .update_draft(&reply_target, &draft_id, &display) + .await + { + tracing::debug!("Draft update failed: {e}"); + } + } + None => { + // Delta channel closed; keep consuming tool events + delta_rx = None; + if tool_status_rx.is_none() { + break; + } + } + } + } } } })) @@ -696,7 +827,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C }; let llm_result = tokio::time::timeout( - Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), + Duration::from_secs(ctx.channel_message_timeout_secs), run_tool_call_loop( active_provider.as_ref(), &mut history, @@ -710,6 +841,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C msg.channel.as_str(), ctx.max_tool_iterations, delta_tx, + tool_status_tx, ), ) .await; @@ -789,7 +921,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C Err(_) => { let timeout_msg = format!( "LLM response timed out after {}s", - CHANNEL_MESSAGE_TIMEOUT_SECS + ctx.channel_message_timeout_secs ); eprintln!( " ❌ {} (elapsed: {}ms)", @@ -1593,10 +1725,13 @@ pub async fn start_channels(config: Config) -> Result<()> { "schedule", "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", )); - tool_descs.push(( - "pushover", - "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.", - )); + // notify tool is conditionally registered (Pushover or Telegram fallback) + if tools_registry.iter().any(|t| t.name() == "notify") { + tool_descs.push(( + "notify", + "Send a push notification (via Pushover or Telegram depending on configuration).", + )); + } if !config.agents.is_empty() { tool_descs.push(( "delegate", @@ -1835,6 +1970,7 @@ pub async fn start_channels(config: Config) -> Result<()> { reliability: Arc::new(config.reliability.clone()), provider_runtime_options, workspace_dir: Arc::new(config.workspace_dir.clone()), + channel_message_timeout_secs: config.agent.channel_message_timeout_secs, }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -2225,6 +2361,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2277,6 +2414,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2338,6 +2476,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2420,6 +2559,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2478,6 +2618,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2531,6 +2672,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -2635,6 +2777,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -2705,6 +2848,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -3095,6 +3239,7 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, }); process_channel_message( @@ -3363,4 +3508,228 @@ mod tests { .contains("listen boom")); assert!(calls.load(Ordering::SeqCst) >= 1); } + + // ── Runtime command parsing tests ───────────────────────── + + #[test] + fn parse_runtime_command_clear() { + assert_eq!( + parse_runtime_command("telegram", "/clear"), + Some(ChannelRuntimeCommand::Clear) + ); + } + + #[test] + fn parse_runtime_command_system() { + assert_eq!( + parse_runtime_command("telegram", "/system"), + Some(ChannelRuntimeCommand::ShowSystem) + ); + } + + #[test] + fn parse_runtime_command_status() { + assert_eq!( + parse_runtime_command("telegram", "/status"), + Some(ChannelRuntimeCommand::ShowStatus) + ); + } + + #[test] + fn parse_runtime_command_help() { + assert_eq!( + parse_runtime_command("telegram", "/help"), + Some(ChannelRuntimeCommand::Help) + ); + } + + #[test] + fn parse_runtime_command_ignores_unsupported_channel() { + assert_eq!(parse_runtime_command("cli", "/clear"), None); + assert_eq!(parse_runtime_command("slack", "/help"), None); + } + + #[test] + fn parse_runtime_command_strips_bot_mention_suffix() { + assert_eq!( + parse_runtime_command("telegram", "/clear@MyBot"), + Some(ChannelRuntimeCommand::Clear) + ); + assert_eq!( + parse_runtime_command("telegram", "/help@SomeBot extra"), + Some(ChannelRuntimeCommand::Help) + ); + } + + #[tokio::test] + async fn handle_runtime_command_clear_clears_history() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel.clone()); + + let histories: ConversationHistoryMap = Arc::new(Mutex::new(HashMap::new())); + histories + .lock() + .unwrap() + .insert("telegram_alice".to_string(), vec![]); + + let ctx = ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(1), + }), + default_provider: Arc::new("test".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.7, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: histories.clone(), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, + }; + + let msg = traits::ChannelMessage { + id: "msg-1".into(), + sender: "alice".into(), + reply_target: "chat-1".into(), + content: "/clear".into(), + channel: "telegram".into(), + timestamp: 1, + }; + + let handled = handle_runtime_command_if_needed(&ctx, &msg, Some(&channel)).await; + assert!(handled); + + assert!(!histories.lock().unwrap().contains_key("telegram_alice")); + + let sent = channel_impl.sent_messages.lock().await; + assert_eq!(sent.len(), 1); + assert!(sent[0].contains("Conversation history cleared.")); + } + + #[tokio::test] + async fn handle_runtime_command_help_lists_all_commands() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel.clone()); + + let ctx = ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(1), + }), + default_provider: Arc::new("test".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.7, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 300, + }; + + let msg = traits::ChannelMessage { + id: "msg-1".into(), + sender: "alice".into(), + reply_target: "chat-1".into(), + content: "/help".into(), + channel: "telegram".into(), + timestamp: 1, + }; + + handle_runtime_command_if_needed(&ctx, &msg, Some(&channel)).await; + + let sent = channel_impl.sent_messages.lock().await; + assert_eq!(sent.len(), 1); + let response = &sent[0]; + assert!(response.contains("/help")); + assert!(response.contains("/model")); + assert!(response.contains("/models")); + assert!(response.contains("/clear")); + assert!(response.contains("/system")); + assert!(response.contains("/status")); + } + + #[tokio::test] + async fn handle_runtime_command_status_shows_config_fields() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel.clone()); + + let ctx = ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(1), + }), + default_provider: Arc::new("openai".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![Box::new(MockPriceTool) as Box]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test".to_string()), + model: Arc::new("gpt-4".to_string()), + temperature: 0.5, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + channel_message_timeout_secs: 120, + }; + + let msg = traits::ChannelMessage { + id: "msg-1".into(), + sender: "alice".into(), + reply_target: "chat-1".into(), + content: "/status".into(), + channel: "telegram".into(), + timestamp: 1, + }; + + handle_runtime_command_if_needed(&ctx, &msg, Some(&channel)).await; + + let sent = channel_impl.sent_messages.lock().await; + assert_eq!(sent.len(), 1); + let response = &sent[0]; + assert!(response.contains("openai")); + assert!(response.contains("gpt-4")); + assert!(response.contains("0.5")); + assert!(response.contains("mock_price")); + assert!(response.contains("noop")); + assert!(response.contains("10")); + assert!(response.contains("120s")); + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index ca0e03b..d9e7725 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -1607,6 +1607,16 @@ impl Channel for TelegramChannel { return Ok(()); } + // Check if edit failed because content is identical (Telegram returns 400 + // with "message is not modified" when the draft already has the final text). + let status = resp.status(); + let resp_body = resp.text().await.unwrap_or_default(); + if status == reqwest::StatusCode::BAD_REQUEST + && resp_body.contains("message is not modified") + { + return Ok(()); + } + // Markdown failed — retry without parse_mode let plain_body = serde_json::json!({ "chat_id": chat_id, @@ -1625,6 +1635,15 @@ impl Channel for TelegramChannel { return Ok(()); } + // Also check plain-text edit for "not modified" + let status = resp.status(); + let resp_body = resp.text().await.unwrap_or_default(); + if status == reqwest::StatusCode::BAD_REQUEST + && resp_body.contains("message is not modified") + { + return Ok(()); + } + // Edit failed entirely — fall back to new message tracing::warn!("Telegram finalize_draft edit failed; falling back to sendMessage"); self.send_text_chunks(text, &chat_id, thread_id.as_deref()) @@ -1672,6 +1691,38 @@ impl Channel for TelegramChannel { let _ = self.get_bot_username().await; } + // Register bot slash-command menu with Telegram + let commands_body = serde_json::json!({ + "commands": [ + {"command": "help", "description": "Show available commands"}, + {"command": "model", "description": "Show or switch model"}, + {"command": "models", "description": "List or switch providers"}, + {"command": "clear", "description": "Clear conversation history"}, + {"command": "system", "description": "Show system prompt"}, + {"command": "status", "description": "Show current configuration"} + ] + }); + match self + .http_client() + .post(self.api_url("setMyCommands")) + .json(&commands_body) + .send() + .await + { + Ok(resp) if resp.status().is_success() => { + tracing::debug!("Telegram setMyCommands registered successfully"); + } + Ok(resp) => { + tracing::debug!( + "Telegram setMyCommands failed with status {}", + resp.status() + ); + } + Err(e) => { + tracing::debug!("Telegram setMyCommands request failed: {e}"); + } + } + tracing::info!("Telegram channel listening for messages..."); loop { diff --git a/src/config/schema.rs b/src/config/schema.rs index 8d9138f..6bd6b0c 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -31,7 +31,7 @@ const SUPPORTED_PROXY_SERVICE_KEYS: &[&str] = &[ "tool.browser", "tool.composio", "tool.http_request", - "tool.pushover", + "tool.notify", "memory.embeddings", "tunnel.custom", ]; @@ -253,6 +253,10 @@ pub struct AgentConfig { pub parallel_tools: bool, #[serde(default = "default_agent_tool_dispatcher")] pub tool_dispatcher: String, + /// Timeout in seconds for processing a single channel message (LLM + tools). + /// Default 300s accommodates on-device LLMs (Ollama) which are slower than cloud APIs. + #[serde(default = "default_channel_message_timeout_secs")] + pub channel_message_timeout_secs: u64, } fn default_agent_max_tool_iterations() -> usize { @@ -267,6 +271,10 @@ fn default_agent_tool_dispatcher() -> String { "auto".into() } +fn default_channel_message_timeout_secs() -> u64 { + 300 +} + impl Default for AgentConfig { fn default() -> Self { Self { @@ -275,6 +283,7 @@ impl Default for AgentConfig { max_history_messages: default_agent_max_history_messages(), parallel_tools: false, tool_dispatcher: default_agent_tool_dispatcher(), + channel_message_timeout_secs: default_channel_message_timeout_secs(), } } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index ce9c6c3..018ab13 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -218,11 +218,14 @@ fn warn_if_high_frequency_agent_job(job: &CronJob) { Schedule::Every { every_ms } => *every_ms < 5 * 60 * 1000, Schedule::Cron { .. } => { let now = Utc::now(); - match ( - next_run_for_schedule(&job.schedule, now), - next_run_for_schedule(&job.schedule, now + chrono::Duration::seconds(1)), - ) { - (Ok(a), Ok(b)) => (b - a).num_minutes() < 5, + match next_run_for_schedule(&job.schedule, now) { + Ok(first) => { + // Get the occurrence *after* the first one to measure the actual interval. + match next_run_for_schedule(&job.schedule, first) { + Ok(second) => (second - first).num_minutes() < 5, + _ => false, + } + } _ => false, } } diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index ca0834b..aea2b08 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -66,7 +66,7 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { max_backoff, move || { let cfg = heartbeat_cfg.clone(); - async move { run_heartbeat_worker(cfg).await } + async move { Box::pin(run_heartbeat_worker(cfg)).await } }, )); } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index d7aed97..dc20177 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -500,7 +500,6 @@ const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [ fn default_model_for_provider(provider: &str) -> String { match canonical_provider_name(provider) { "anthropic" => "claude-sonnet-4-5-20250929".into(), - "openrouter" => "anthropic/claude-sonnet-4.6".into(), "openai" => "gpt-5.2".into(), "openai-codex" => "gpt-5-codex".into(), "venice" => "zai-org-glm-5".into(), @@ -520,7 +519,6 @@ fn default_model_for_provider(provider: &str) -> String { "gemini" => "gemini-2.5-pro".into(), "kimi-code" => "kimi-for-coding".into(), "nvidia" => "meta/llama-3.3-70b-instruct".into(), - "astrai" => "anthropic/claude-sonnet-4.6".into(), _ => "anthropic/claude-sonnet-4.6".into(), } } @@ -5190,7 +5188,7 @@ mod tests { let config = Config { workspace_dir: tmp.path().to_path_buf(), - default_provider: Some("venice".to_string()), + default_provider: Some("unknown-provider".to_string()), ..Config::default() }; diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 074ee45..40e0d94 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -1054,7 +1054,10 @@ impl Provider for OpenAiCompatibleProvider { let url = self.chat_completions_url(); let response = self - .apply_auth_header(self.http_client().post(&url).json(&native_request), credential) + .apply_auth_header( + self.http_client().post(&url).json(&native_request), + credential, + ) .send() .await?; diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 85f9019..82b7d83 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -45,14 +45,12 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { return true; } - let model_catalog_mismatch = msg_lower.contains("model") + msg_lower.contains("model") && (msg_lower.contains("not found") || msg_lower.contains("unknown") || msg_lower.contains("unsupported") || msg_lower.contains("does not exist") - || msg_lower.contains("invalid")); - - model_catalog_mismatch + || msg_lower.contains("invalid")) } /// Check if an error is a rate-limit (429) error. diff --git a/src/security/policy.rs b/src/security/policy.rs index 806a399..d7a4e6f 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -184,6 +184,19 @@ fn contains_single_ampersand(s: &str) -> bool { false } +/// Strip safe stderr redirection patterns before policy checks. +/// +/// Removes `2>/dev/null`, `2> /dev/null`, and `2>&1` so they don't +/// trigger the generic `>` or `&` blockers. +fn strip_safe_stderr(s: &str) -> String { + let mut result = s.to_string(); + // Order matters: longest patterns first + for pat in ["2> /dev/null", "2>/dev/null", "2>&1"] { + result = result.replace(pat, ""); + } + result +} + impl SecurityPolicy { /// Classify command risk. Any high-risk segment marks the whole command high. pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { @@ -312,16 +325,23 @@ impl SecurityPolicy { approved: bool, ) -> Result { if !self.is_command_allowed(command) { + tracing::debug!(command, "Shell command blocked by allowlist"); return Err(format!("Command not allowed by security policy: {command}")); } let risk = self.command_risk_level(command); + tracing::trace!(command, ?risk, approved, "Shell command risk assessed"); if risk == CommandRiskLevel::High { if self.block_high_risk_commands { + tracing::debug!( + command, + "Shell command blocked: high-risk disallowed by policy" + ); return Err("Command blocked: high-risk command is disallowed by policy".into()); } if self.autonomy == AutonomyLevel::Supervised && !approved { + tracing::debug!(command, "Shell command blocked: high-risk needs approval"); return Err( "Command requires explicit approval (approved=true): high-risk operation" .into(), @@ -334,11 +354,13 @@ impl SecurityPolicy { && self.require_approval_for_medium_risk && !approved { + tracing::debug!(command, "Shell command blocked: medium-risk needs approval"); return Err( "Command requires explicit approval (approved=true): medium-risk operation".into(), ); } + tracing::debug!(command, ?risk, "Shell command allowed by policy"); Ok(risk) } @@ -353,37 +375,46 @@ impl SecurityPolicy { /// - Blocks dangerous arguments (e.g. `find -exec`, `git config`) pub fn is_command_allowed(&self, command: &str) -> bool { if self.autonomy == AutonomyLevel::ReadOnly { + tracing::trace!(command, "Command blocked: read-only mode"); return false; } + // Strip safe stderr redirections (2>/dev/null, 2>&1) before + // operator checks so they don't trigger the generic `>` or `&` blockers. + let sanitized = strip_safe_stderr(command); + // Block subshell/expansion operators — these allow hiding arbitrary // commands inside an allowed command (e.g. `echo $(rm -rf /)`) - if command.contains('`') - || command.contains("$(") - || command.contains("${") - || command.contains("<(") - || command.contains(">(") + if sanitized.contains('`') + || sanitized.contains("$(") + || sanitized.contains("${") + || sanitized.contains("<(") + || sanitized.contains(">(") { + tracing::debug!(command, "Command blocked: subshell/expansion operator"); return false; } - // Block output redirections — they can write to arbitrary paths - if command.contains('>') { + // Block output redirections that write to arbitrary paths. + if sanitized.contains('>') { + tracing::debug!(command, "Command blocked: output redirection"); return false; } // Block `tee` — it can write to arbitrary files, bypassing the // redirect check above (e.g. `echo secret | tee /etc/crontab`) - if command + if sanitized .split_whitespace() .any(|w| w == "tee" || w.ends_with("/tee")) { + tracing::debug!(command, "Command blocked: tee can write arbitrary files"); return false; } // Block background command chaining (`&`), which can hide extra // sub-commands and outlive timeout expectations. Keep `&&` allowed. - if contains_single_ampersand(command) { + if contains_single_ampersand(&sanitized) { + tracing::debug!(command, "Command blocked: background & operator"); return false; } @@ -414,10 +445,12 @@ impl SecurityPolicy { continue; } - if !self - .allowed_commands - .iter() - .any(|allowed| allowed == base_cmd) + let allow_all = self.allowed_commands.iter().any(|c| c == "*"); + if !allow_all + && !self + .allowed_commands + .iter() + .any(|allowed| allowed == base_cmd) { return false; } @@ -702,6 +735,21 @@ mod tests { assert!(!p.is_command_allowed("node malicious.js")); } + #[test] + fn wildcard_allowed_commands_permits_any_binary() { + let p = SecurityPolicy { + allowed_commands: vec!["*".into()], + ..SecurityPolicy::default() + }; + assert!(p.is_command_allowed("curl http://example.com")); + assert!(p.is_command_allowed("wget http://example.com")); + assert!(p.is_command_allowed("python3 script.py")); + assert!(p.is_command_allowed("node app.js")); + // Subshell/redirect blocks still apply + assert!(!p.is_command_allowed("echo $(rm -rf /)")); + assert!(!p.is_command_allowed("echo hello > /etc/passwd")); + } + #[test] fn readonly_blocks_all_commands() { let p = readonly_policy(); @@ -1084,6 +1132,22 @@ mod tests { let p = default_policy(); assert!(!p.is_command_allowed("echo secret > /etc/crontab")); assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); + assert!(!p.is_command_allowed("cat file > /tmp/out")); + } + + #[test] + fn stderr_to_devnull_allowed() { + let p = SecurityPolicy { + allowed_commands: vec!["*".into()], + ..SecurityPolicy::default() + }; + assert!(p.is_command_allowed("ls -la /tmp/*.py 2>/dev/null")); + assert!(p.is_command_allowed("ls -la 2> /dev/null")); + assert!(p.is_command_allowed("grep pattern file 2>&1")); + assert!(p.is_command_allowed("cmd 2>/dev/null | grep foo")); + // Stdout redirect still blocked + assert!(!p.is_command_allowed("echo hello > /tmp/file")); + assert!(!p.is_command_allowed("echo hello 1> /tmp/file")); } #[test] diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index fe1a48e..3e1c590 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -46,17 +46,19 @@ impl HttpRequestTool { if self.allowed_domains.is_empty() { anyhow::bail!( - "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml" + "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml or use [\"*\"] to allow all domains" ); } + let allow_all = self.allowed_domains.iter().any(|d| d == "*"); + let host = extract_host(url)?; if is_private_or_local_host(&host) { anyhow::bail!("Blocked local/private host: {host}"); } - if !host_matches_allowlist(&host, &self.allowed_domains) { + if !allow_all && !host_matches_allowlist(&host, &self.allowed_domains) { anyhow::bail!("Host '{host}' is not in http_request.allowed_domains"); } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index a472afc..21881e5 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -19,8 +19,8 @@ pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod notify; pub mod proxy_config; -pub mod pushover; pub mod schedule; pub mod schema; pub mod screenshot; @@ -49,8 +49,8 @@ pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use notify::NotifyTool; pub use proxy_config::ProxyConfigTool; -pub use pushover::PushoverTool; pub use schedule::ScheduleTool; #[allow(unused_imports)] pub use schema::{CleaningStrategy, SchemaCleanr}; @@ -151,12 +151,16 @@ pub fn all_tools_with_runtime( security.clone(), workspace_dir.to_path_buf(), )), - Box::new(PushoverTool::new( - security.clone(), - workspace_dir.to_path_buf(), - )), ]; + if let Some(notify_tool) = NotifyTool::detect( + security.clone(), + workspace_dir, + root_config.channels_config.telegram.as_ref(), + ) { + tools.push(Box::new(notify_tool)); + } + if browser_config.enabled { // Add legacy browser_open tool for simple URL opening tools.push(Box::new(BrowserOpenTool::new( @@ -294,7 +298,8 @@ mod tests { let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); assert!(names.contains(&"schedule")); - assert!(names.contains(&"pushover")); + // notify tool is conditionally registered — not present without credentials + assert!(!names.contains(&"notify")); assert!(names.contains(&"proxy_config")); } @@ -333,7 +338,8 @@ mod tests { ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); - assert!(names.contains(&"pushover")); + // notify tool is conditionally registered — not present without credentials + assert!(!names.contains(&"notify")); assert!(names.contains(&"proxy_config")); } diff --git a/src/tools/notify.rs b/src/tools/notify.rs new file mode 100644 index 0000000..8ff55a9 --- /dev/null +++ b/src/tools/notify.rs @@ -0,0 +1,609 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::TelegramConfig; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; + +const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json"; +const NOTIFY_REQUEST_TIMEOUT_SECS: u64 = 15; + +enum NotifyBackend { + Pushover { token: String, user_key: String }, + Telegram { bot_token: String, chat_id: String }, +} + +pub struct NotifyTool { + security: Arc, + backend: NotifyBackend, +} + +impl NotifyTool { + /// Detect the best available notification backend. + /// + /// Checks Pushover credentials first (from `.env`), then Telegram config. + /// Returns `None` if neither backend is available. + pub fn detect( + security: Arc, + workspace_dir: &std::path::Path, + telegram_config: Option<&TelegramConfig>, + ) -> Option { + // Try Pushover first + if let Some((token, user_key)) = Self::read_pushover_credentials(workspace_dir) { + return Some(Self { + security, + backend: NotifyBackend::Pushover { token, user_key }, + }); + } + + // Fall back to Telegram + if let Some(tg) = telegram_config { + if let Some(chat_id) = tg.allowed_users.first() { + if !tg.bot_token.is_empty() && !chat_id.is_empty() { + return Some(Self { + security, + backend: NotifyBackend::Telegram { + bot_token: tg.bot_token.clone(), + chat_id: chat_id.clone(), + }, + }); + } + } + } + + None + } + + fn parse_env_value(raw: &str) -> String { + let raw = raw.trim(); + + let unquoted = if raw.len() >= 2 + && ((raw.starts_with('"') && raw.ends_with('"')) + || (raw.starts_with('\'') && raw.ends_with('\''))) + { + &raw[1..raw.len() - 1] + } else { + raw + }; + + // Keep support for inline comments in unquoted values: + // KEY=value # comment + unquoted.split_once(" #").map_or_else( + || unquoted.trim().to_string(), + |(value, _)| value.trim().to_string(), + ) + } + + fn read_pushover_credentials(workspace_dir: &std::path::Path) -> Option<(String, String)> { + let env_path = workspace_dir.join(".env"); + let content = std::fs::read_to_string(&env_path).ok()?; + + let mut token = None; + let mut user_key = None; + + for line in content.lines() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + continue; + } + let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line); + if let Some((key, value)) = line.split_once('=') { + let key = key.trim(); + let value = Self::parse_env_value(value); + + if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") { + token = Some(value); + } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") { + user_key = Some(value); + } + } + } + + Some((token?, user_key?)) + } + + fn backend_label(&self) -> &str { + match &self.backend { + NotifyBackend::Pushover { .. } => "Pushover", + NotifyBackend::Telegram { .. } => "Telegram", + } + } + + async fn send_pushover( + token: &str, + user_key: &str, + message: &str, + title: Option<&str>, + priority: Option, + sound: Option<&str>, + ) -> anyhow::Result { + let mut form = reqwest::multipart::Form::new() + .text("token", token.to_owned()) + .text("user", user_key.to_owned()) + .text("message", message.to_owned()); + + if let Some(title) = title { + form = form.text("title", title.to_owned()); + } + if let Some(priority) = priority { + form = form.text("priority", priority.to_string()); + } + if let Some(sound) = sound { + form = form.text("sound", sound.to_owned()); + } + + let client = crate::config::build_runtime_proxy_client_with_timeouts( + "tool.notify", + NOTIFY_REQUEST_TIMEOUT_SECS, + 10, + ); + let response = client.post(PUSHOVER_API_URL).multipart(form).send().await?; + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if !status.is_success() { + return Ok(ToolResult { + success: false, + output: body, + error: Some(format!("Pushover API returned status {}", status)), + }); + } + + let api_status = serde_json::from_str::(&body) + .ok() + .and_then(|json| json.get("status").and_then(|value| value.as_i64())); + + if api_status == Some(1) { + Ok(ToolResult { + success: true, + output: format!("Notification sent via Pushover. Response: {}", body), + error: None, + }) + } else { + Ok(ToolResult { + success: false, + output: body, + error: Some("Pushover API returned an application-level error".into()), + }) + } + } + + async fn send_telegram( + bot_token: &str, + chat_id: &str, + message: &str, + title: Option<&str>, + ) -> anyhow::Result { + let text = match title { + Some(t) if !t.is_empty() => format!("*{}*\n{}", t, message), + _ => message.to_owned(), + }; + + let url = format!("https://api.telegram.org/bot{}/sendMessage", bot_token); + + let client = crate::config::build_runtime_proxy_client_with_timeouts( + "tool.notify", + NOTIFY_REQUEST_TIMEOUT_SECS, + 10, + ); + let response = client + .post(&url) + .json(&json!({ + "chat_id": chat_id, + "text": text, + "parse_mode": "Markdown", + })) + .send() + .await?; + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if !status.is_success() { + return Ok(ToolResult { + success: false, + output: body, + error: Some(format!("Telegram API returned status {}", status)), + }); + } + + let ok = serde_json::from_str::(&body) + .ok() + .and_then(|json| json.get("ok").and_then(|v| v.as_bool())); + + if ok == Some(true) { + Ok(ToolResult { + success: true, + output: format!("Notification sent via Telegram. Response: {}", body), + error: None, + }) + } else { + Ok(ToolResult { + success: false, + output: body, + error: Some("Telegram API returned an application-level error".into()), + }) + } + } +} + +#[async_trait] +impl Tool for NotifyTool { + fn name(&self) -> &str { + "notify" + } + + fn description(&self) -> &str { + match &self.backend { + NotifyBackend::Pushover { .. } => { + "Send a push notification to your device via Pushover. Supports title, priority, and sound options." + } + NotifyBackend::Telegram { .. } => { + "Send a notification message via Telegram. Supports optional title (priority/sound are ignored)." + } + } + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The notification message to send" + }, + "title": { + "type": "string", + "description": "Optional notification title" + }, + "priority": { + "type": "integer", + "enum": [-2, -1, 0, 1, 2], + "description": "Message priority (Pushover only): -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)" + }, + "sound": { + "type": "string", + "description": "Notification sound override (Pushover only, e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)" + } + }, + "required": ["message"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + let message = args + .get("message") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|v| !v.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))? + .to_string(); + + let title = args.get("title").and_then(|v| v.as_str()).map(String::from); + + let priority = match args.get("priority").and_then(|v| v.as_i64()) { + Some(value) if (-2..=2).contains(&value) => Some(value), + Some(value) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Invalid 'priority': {value}. Expected integer in range -2..=2" + )), + }) + } + None => None, + }; + + let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from); + + match &self.backend { + NotifyBackend::Pushover { token, user_key } => { + Self::send_pushover( + token, + user_key, + &message, + title.as_deref(), + priority, + sound.as_deref(), + ) + .await + } + NotifyBackend::Telegram { bot_token, chat_id } => { + Self::send_telegram(bot_token, chat_id, &message, title.as_deref()).await + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::AutonomyLevel; + use std::fs; + use tempfile::TempDir; + + fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc { + Arc::new(SecurityPolicy { + autonomy: level, + max_actions_per_hour, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + fn make_pushover_tool(security: Arc) -> NotifyTool { + NotifyTool { + security, + backend: NotifyBackend::Pushover { + token: "test_token".into(), + user_key: "test_user".into(), + }, + } + } + + fn make_telegram_tool(security: Arc) -> NotifyTool { + NotifyTool { + security, + backend: NotifyBackend::Telegram { + bot_token: "123:ABC".into(), + chat_id: "456".into(), + }, + } + } + + #[test] + fn notify_tool_name() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + assert_eq!(tool.name(), "notify"); + } + + #[test] + fn notify_tool_description_pushover() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + assert!(tool.description().contains("Pushover")); + } + + #[test] + fn notify_tool_description_telegram() { + let tool = make_telegram_tool(test_security(AutonomyLevel::Full, 100)); + assert!(tool.description().contains("Telegram")); + } + + #[test] + fn notify_tool_has_parameters_schema() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + let schema = tool.parameters_schema(); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"].get("message").is_some()); + } + + #[test] + fn notify_tool_requires_message() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::Value::String("message".to_string()))); + } + + #[test] + fn credentials_parsed_from_env_file() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n", + ) + .unwrap(); + + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_some()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "testtoken123"); + assert_eq!(user_key, "userkey456"); + } + + #[test] + fn credentials_none_without_env_file() { + let tmp = TempDir::new().unwrap(); + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_none()); + } + + #[test] + fn credentials_none_without_token() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap(); + + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_none()); + } + + #[test] + fn credentials_none_without_user_key() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap(); + + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_none()); + } + + #[test] + fn credentials_ignore_comments() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap(); + + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_some()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "realtoken"); + assert_eq!(user_key, "realuser"); + } + + #[test] + fn notify_tool_supports_priority() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("priority").is_some()); + } + + #[test] + fn notify_tool_supports_sound() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("sound").is_some()); + } + + #[test] + fn credentials_support_export_and_quoted_values() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n", + ) + .unwrap(); + + let result = NotifyTool::read_pushover_credentials(tmp.path()); + assert!(result.is_some()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "quotedtoken"); + assert_eq!(user_key, "quoteduser"); + } + + #[tokio::test] + async fn execute_blocks_readonly_mode() { + let tool = make_pushover_tool(test_security(AutonomyLevel::ReadOnly, 100)); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn execute_blocks_rate_limit() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 0)); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("rate limit")); + } + + #[tokio::test] + async fn execute_rejects_priority_out_of_range() { + let tool = make_pushover_tool(test_security(AutonomyLevel::Full, 100)); + + let result = tool + .execute(json!({"message": "hello", "priority": 5})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("-2..=2")); + } + + #[test] + fn detect_returns_none_when_no_backend_available() { + let tmp = TempDir::new().unwrap(); + let security = test_security(AutonomyLevel::Full, 100); + + let result = NotifyTool::detect(security, tmp.path(), None); + assert!(result.is_none()); + } + + #[test] + fn detect_prefers_pushover_when_both_available() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_TOKEN=token\nPUSHOVER_USER_KEY=user\n").unwrap(); + + let tg = TelegramConfig { + bot_token: "123:ABC".into(), + allowed_users: vec!["456".into()], + stream_mode: crate::config::StreamMode::Off, + draft_update_interval_ms: 1000, + mention_only: false, + }; + + let security = test_security(AutonomyLevel::Full, 100); + let tool = NotifyTool::detect(security, tmp.path(), Some(&tg)); + assert!(tool.is_some()); + assert_eq!(tool.unwrap().backend_label(), "Pushover"); + } + + #[test] + fn detect_falls_back_to_telegram_when_no_pushover_credentials() { + let tmp = TempDir::new().unwrap(); + let tg = TelegramConfig { + bot_token: "123:ABC".into(), + allowed_users: vec!["456".into()], + stream_mode: crate::config::StreamMode::Off, + draft_update_interval_ms: 1000, + mention_only: false, + }; + + let security = test_security(AutonomyLevel::Full, 100); + let tool = NotifyTool::detect(security, tmp.path(), Some(&tg)); + assert!(tool.is_some()); + assert_eq!(tool.unwrap().backend_label(), "Telegram"); + } + + #[test] + fn detect_returns_none_for_telegram_with_empty_allowed_users() { + let tmp = TempDir::new().unwrap(); + let tg = TelegramConfig { + bot_token: "123:ABC".into(), + allowed_users: vec![], + stream_mode: crate::config::StreamMode::Off, + draft_update_interval_ms: 1000, + mention_only: false, + }; + + let security = test_security(AutonomyLevel::Full, 100); + let result = NotifyTool::detect(security, tmp.path(), Some(&tg)); + assert!(result.is_none()); + } + + #[test] + fn telegram_backend_formats_message_with_title() { + // Verify the format logic used by send_telegram + let title = Some("Alert"); + let message = "Server is down"; + let text = match title { + Some(t) if !t.is_empty() => format!("*{}*\n{}", t, message), + _ => message.to_owned(), + }; + assert_eq!(text, "*Alert*\nServer is down"); + } + + #[test] + fn telegram_backend_formats_message_without_title() { + let title: Option<&str> = None; + let message = "Server is down"; + let text = match title { + Some(t) if !t.is_empty() => format!("*{}*\n{}", t, message), + _ => message.to_owned(), + }; + assert_eq!(text, "Server is down"); + } +} diff --git a/src/tools/proxy_config.rs b/src/tools/proxy_config.rs index 3ddde9e..5b8b259 100644 --- a/src/tools/proxy_config.rs +++ b/src/tools/proxy_config.rs @@ -93,6 +93,7 @@ impl ProxyConfigTool { anyhow::bail!("'{field}' must be a string or string[]") } + #[allow(clippy::option_option)] // Outer=field present, inner=value-or-null (partial update) fn parse_optional_string_update( args: &Value, field: &str, diff --git a/src/tools/pushover.rs b/src/tools/pushover.rs deleted file mode 100644 index 23d980b..0000000 --- a/src/tools/pushover.rs +++ /dev/null @@ -1,433 +0,0 @@ -use super::traits::{Tool, ToolResult}; -use crate::security::SecurityPolicy; -use async_trait::async_trait; -use serde_json::json; -use std::path::PathBuf; -use std::sync::Arc; - -const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json"; -const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15; - -pub struct PushoverTool { - security: Arc, - workspace_dir: PathBuf, -} - -impl PushoverTool { - pub fn new(security: Arc, workspace_dir: PathBuf) -> Self { - Self { - security, - workspace_dir, - } - } - - fn parse_env_value(raw: &str) -> String { - let raw = raw.trim(); - - let unquoted = if raw.len() >= 2 - && ((raw.starts_with('"') && raw.ends_with('"')) - || (raw.starts_with('\'') && raw.ends_with('\''))) - { - &raw[1..raw.len() - 1] - } else { - raw - }; - - // Keep support for inline comments in unquoted values: - // KEY=value # comment - unquoted.split_once(" #").map_or_else( - || unquoted.trim().to_string(), - |(value, _)| value.trim().to_string(), - ) - } - - fn get_credentials(&self) -> anyhow::Result<(String, String)> { - let env_path = self.workspace_dir.join(".env"); - let content = std::fs::read_to_string(&env_path) - .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?; - - let mut token = None; - let mut user_key = None; - - for line in content.lines() { - let line = line.trim(); - if line.starts_with('#') || line.is_empty() { - continue; - } - let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line); - if let Some((key, value)) = line.split_once('=') { - let key = key.trim(); - let value = Self::parse_env_value(value); - - if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") { - token = Some(value); - } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") { - user_key = Some(value); - } - } - } - - let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?; - let user_key = - user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?; - - Ok((token, user_key)) - } -} - -#[async_trait] -impl Tool for PushoverTool { - fn name(&self) -> &str { - "pushover" - } - - fn description(&self) -> &str { - "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file." - } - - fn parameters_schema(&self) -> serde_json::Value { - json!({ - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "The notification message to send" - }, - "title": { - "type": "string", - "description": "Optional notification title" - }, - "priority": { - "type": "integer", - "enum": [-2, -1, 0, 1, 2], - "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)" - }, - "sound": { - "type": "string", - "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)" - } - }, - "required": ["message"] - }) - } - - async fn execute(&self, args: serde_json::Value) -> anyhow::Result { - if !self.security.can_act() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("Action blocked: autonomy is read-only".into()), - }); - } - - if !self.security.record_action() { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some("Action blocked: rate limit exceeded".into()), - }); - } - - let message = args - .get("message") - .and_then(|v| v.as_str()) - .map(str::trim) - .filter(|v| !v.is_empty()) - .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))? - .to_string(); - - let title = args.get("title").and_then(|v| v.as_str()).map(String::from); - - let priority = match args.get("priority").and_then(|v| v.as_i64()) { - Some(value) if (-2..=2).contains(&value) => Some(value), - Some(value) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "Invalid 'priority': {value}. Expected integer in range -2..=2" - )), - }) - } - None => None, - }; - - let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from); - - let (token, user_key) = self.get_credentials()?; - - let mut form = reqwest::multipart::Form::new() - .text("token", token) - .text("user", user_key) - .text("message", message); - - if let Some(title) = title { - form = form.text("title", title); - } - - if let Some(priority) = priority { - form = form.text("priority", priority.to_string()); - } - - if let Some(sound) = sound { - form = form.text("sound", sound); - } - - let client = crate::config::build_runtime_proxy_client_with_timeouts( - "tool.pushover", - PUSHOVER_REQUEST_TIMEOUT_SECS, - 10, - ); - let response = client.post(PUSHOVER_API_URL).multipart(form).send().await?; - - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - - if !status.is_success() { - return Ok(ToolResult { - success: false, - output: body, - error: Some(format!("Pushover API returned status {}", status)), - }); - } - - let api_status = serde_json::from_str::(&body) - .ok() - .and_then(|json| json.get("status").and_then(|value| value.as_i64())); - - if api_status == Some(1) { - Ok(ToolResult { - success: true, - output: format!( - "Pushover notification sent successfully. Response: {}", - body - ), - error: None, - }) - } else { - Ok(ToolResult { - success: false, - output: body, - error: Some("Pushover API returned an application-level error".into()), - }) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::security::AutonomyLevel; - use std::fs; - use tempfile::TempDir; - - fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc { - Arc::new(SecurityPolicy { - autonomy: level, - max_actions_per_hour, - workspace_dir: std::env::temp_dir(), - ..SecurityPolicy::default() - }) - } - - #[test] - fn pushover_tool_name() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - assert_eq!(tool.name(), "pushover"); - } - - #[test] - fn pushover_tool_description() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - assert!(!tool.description().is_empty()); - } - - #[test] - fn pushover_tool_has_parameters_schema() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - let schema = tool.parameters_schema(); - assert_eq!(schema["type"], "object"); - assert!(schema["properties"].get("message").is_some()); - } - - #[test] - fn pushover_tool_requires_message() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - let schema = tool.parameters_schema(); - let required = schema["required"].as_array().unwrap(); - assert!(required.contains(&serde_json::Value::String("message".to_string()))); - } - - #[test] - fn credentials_parsed_from_env_file() { - let tmp = TempDir::new().unwrap(); - let env_path = tmp.path().join(".env"); - fs::write( - &env_path, - "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n", - ) - .unwrap(); - - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_ok()); - let (token, user_key) = result.unwrap(); - assert_eq!(token, "testtoken123"); - assert_eq!(user_key, "userkey456"); - } - - #[test] - fn credentials_fail_without_env_file() { - let tmp = TempDir::new().unwrap(); - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_err()); - } - - #[test] - fn credentials_fail_without_token() { - let tmp = TempDir::new().unwrap(); - let env_path = tmp.path().join(".env"); - fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap(); - - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_err()); - } - - #[test] - fn credentials_fail_without_user_key() { - let tmp = TempDir::new().unwrap(); - let env_path = tmp.path().join(".env"); - fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap(); - - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_err()); - } - - #[test] - fn credentials_ignore_comments() { - let tmp = TempDir::new().unwrap(); - let env_path = tmp.path().join(".env"); - fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap(); - - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_ok()); - let (token, user_key) = result.unwrap(); - assert_eq!(token, "realtoken"); - assert_eq!(user_key, "realuser"); - } - - #[test] - fn pushover_tool_supports_priority() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - let schema = tool.parameters_schema(); - assert!(schema["properties"].get("priority").is_some()); - } - - #[test] - fn pushover_tool_supports_sound() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - let schema = tool.parameters_schema(); - assert!(schema["properties"].get("sound").is_some()); - } - - #[test] - fn credentials_support_export_and_quoted_values() { - let tmp = TempDir::new().unwrap(); - let env_path = tmp.path().join(".env"); - fs::write( - &env_path, - "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n", - ) - .unwrap(); - - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - tmp.path().to_path_buf(), - ); - let result = tool.get_credentials(); - - assert!(result.is_ok()); - let (token, user_key) = result.unwrap(); - assert_eq!(token, "quotedtoken"); - assert_eq!(user_key, "quoteduser"); - } - - #[tokio::test] - async fn execute_blocks_readonly_mode() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::ReadOnly, 100), - PathBuf::from("/tmp"), - ); - - let result = tool.execute(json!({"message": "hello"})).await.unwrap(); - assert!(!result.success); - assert!(result.error.unwrap().contains("read-only")); - } - - #[tokio::test] - async fn execute_blocks_rate_limit() { - let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp")); - - let result = tool.execute(json!({"message": "hello"})).await.unwrap(); - assert!(!result.success); - assert!(result.error.unwrap().contains("rate limit")); - } - - #[tokio::test] - async fn execute_rejects_priority_out_of_range() { - let tool = PushoverTool::new( - test_security(AutonomyLevel::Full, 100), - PathBuf::from("/tmp"), - ); - - let result = tool - .execute(json!({"message": "hello", "priority": 5})) - .await - .unwrap(); - - assert!(!result.success); - assert!(result.error.unwrap().contains("-2..=2")); - } -} diff --git a/src/tools/shell.rs b/src/tools/shell.rs index 031ed4b..bafbce9 100644 --- a/src/tools/shell.rs +++ b/src/tools/shell.rs @@ -66,7 +66,10 @@ impl Tool for ShellTool { .and_then(|v| v.as_bool()) .unwrap_or(false); + tracing::debug!(command, approved, "Shell tool invoked"); + if self.security.is_rate_limited() { + tracing::warn!(command, "Shell command rejected: rate limit exceeded"); return Ok(ToolResult { success: false, output: String::new(), @@ -122,9 +125,22 @@ impl Tool for ShellTool { match result { Ok(Ok(output)) => { + let exit_code = output.status.code(); + let success = output.status.success(); + tracing::debug!( + command, + ?exit_code, + success, + stdout_bytes = output.stdout.len(), + stderr_bytes = output.stderr.len(), + "Shell command completed" + ); + let mut stdout = String::from_utf8_lossy(&output.stdout).to_string(); let mut stderr = String::from_utf8_lossy(&output.stderr).to_string(); + tracing::trace!(command, stdout = %stdout, stderr = %stderr, "Shell command output"); + // Truncate output to prevent OOM if stdout.len() > MAX_OUTPUT_BYTES { stdout.truncate(stdout.floor_char_boundary(MAX_OUTPUT_BYTES)); @@ -136,7 +152,7 @@ impl Tool for ShellTool { } Ok(ToolResult { - success: output.status.success(), + success, output: stdout, error: if stderr.is_empty() { None @@ -145,18 +161,28 @@ impl Tool for ShellTool { }, }) } - Ok(Err(e)) => Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!("Failed to execute command: {e}")), - }), - Err(_) => Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "Command timed out after {SHELL_TIMEOUT_SECS}s and was killed" - )), - }), + Ok(Err(e)) => { + tracing::warn!(command, error = %e, "Shell command failed to execute"); + Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute command: {e}")), + }) + } + Err(_) => { + tracing::warn!( + command, + timeout_secs = SHELL_TIMEOUT_SECS, + "Shell command timed out" + ); + Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Command timed out after {SHELL_TIMEOUT_SECS}s and was killed" + )), + }) + } } } }