diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index fcaedf9..dfce36a 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -476,6 +476,7 @@ pub async fn run( mem.clone(), composio_key, &config.browser, + &config.http_request, &config.agents, config.api_key.as_deref(), ); @@ -966,4 +967,213 @@ I will now call the tool with this payload: let recalled = mem.recall("45", 5).await.unwrap(); assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Tool Call Parsing Edge Cases + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_tool_calls_handles_empty_tool_result() { + // Recovery: Empty tool_result tag should be handled gracefully + let response = r#"I'll run that command. + + + +Done."#; + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("Done.")); + assert!(calls.is_empty()); + } + + #[test] + fn parse_arguments_value_handles_null() { + // Recovery: null arguments are returned as-is (Value::Null) + let value = serde_json::json!(null); + let result = parse_arguments_value(Some(&value)); + assert!(result.is_null()); + } + + #[test] + fn parse_tool_calls_handles_empty_tool_calls_array() { + // Recovery: Empty tool_calls array returns original response (no tool parsing) + let response = r#"{"content": "Hello", "tool_calls": []}"#; + let (text, calls) = parse_tool_calls(response); + // When tool_calls is empty, the entire JSON is returned as text + assert!(text.contains("Hello")); + assert!(calls.is_empty()); + } + + #[test] + fn parse_tool_calls_handles_whitespace_only_name() { + // Recovery: Whitespace-only tool name should return None + let value = serde_json::json!({"function": {"name": " ", "arguments": {}}}); + let result = parse_tool_call_value(&value); + assert!(result.is_none()); + } + + #[test] + fn parse_tool_calls_handles_empty_string_arguments() { + // Recovery: Empty string arguments should be handled + let value = serde_json::json!({"name": "test", "arguments": ""}); + let result = parse_tool_call_value(&value); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "test"); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - History Management + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn trim_history_with_no_system_prompt() { + // Recovery: History without system prompt should trim correctly + let mut history = vec![]; + for i in 0..MAX_HISTORY_MESSAGES + 20 { + history.push(ChatMessage::user(format!("msg {i}"))); + } + trim_history(&mut history); + assert_eq!(history.len(), MAX_HISTORY_MESSAGES); + } + + #[test] + fn trim_history_preserves_role_ordering() { + // Recovery: After trimming, role ordering should remain consistent + let mut history = vec![ChatMessage::system("system")]; + for i in 0..MAX_HISTORY_MESSAGES + 10 { + history.push(ChatMessage::user(format!("user {i}"))); + history.push(ChatMessage::assistant(format!("assistant {i}"))); + } + trim_history(&mut history); + assert_eq!(history[0].role, "system"); + assert_eq!(history[history.len() - 1].role, "assistant"); + } + + #[test] + fn trim_history_with_only_system_prompt() { + // Recovery: Only system prompt should not be trimmed + let mut history = vec![ChatMessage::system("system prompt")]; + trim_history(&mut history); + assert_eq!(history.len(), 1); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Arguments Parsing + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_arguments_value_handles_invalid_json_string() { + // Recovery: Invalid JSON string should return empty object + let value = serde_json::Value::String("not valid json".to_string()); + let result = parse_arguments_value(Some(&value)); + assert!(result.is_object()); + assert!(result.as_object().unwrap().is_empty()); + } + + #[test] + fn parse_arguments_value_handles_none() { + // Recovery: None arguments should return empty object + let result = parse_arguments_value(None); + assert!(result.is_object()); + assert!(result.as_object().unwrap().is_empty()); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - JSON Extraction + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn extract_json_values_handles_empty_string() { + // Recovery: Empty input should return empty vec + let result = extract_json_values(""); + assert!(result.is_empty()); + } + + #[test] + fn extract_json_values_handles_whitespace_only() { + // Recovery: Whitespace only should return empty vec + let result = extract_json_values(" \n\t "); + assert!(result.is_empty()); + } + + #[test] + fn extract_json_values_handles_multiple_objects() { + // Recovery: Multiple JSON objects should all be extracted + let input = r#"{"a": 1}{"b": 2}{"c": 3}"#; + let result = extract_json_values(input); + assert_eq!(result.len(), 3); + } + + #[test] + fn extract_json_values_handles_arrays() { + // Recovery: JSON arrays should be extracted + let input = r#"[1, 2, 3]{"key": "value"}"#; + let result = extract_json_values(input); + assert_eq!(result.len(), 2); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Constants Validation + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn max_tool_iterations_is_reasonable() { + // Recovery: MAX_TOOL_ITERATIONS should be set to prevent runaway loops + assert!(MAX_TOOL_ITERATIONS > 0); + assert!(MAX_TOOL_ITERATIONS <= 100); + } + + #[test] + fn max_history_messages_is_reasonable() { + // Recovery: MAX_HISTORY_MESSAGES should be set to prevent memory bloat + assert!(MAX_HISTORY_MESSAGES > 0); + assert!(MAX_HISTORY_MESSAGES <= 1000); + } + + // ═══════════════════════════════════════════════════════════════════════ + // Recovery Tests - Tool Call Value Parsing + // ═══════════════════════════════════════════════════════════════════════ + + #[test] + fn parse_tool_call_value_handles_missing_name_field() { + // Recovery: Missing name field should return None + let value = serde_json::json!({"function": {"arguments": {}}}); + let result = parse_tool_call_value(&value); + assert!(result.is_none()); + } + + #[test] + fn parse_tool_call_value_handles_top_level_name() { + // Recovery: Tool call with name at top level (non-OpenAI format) + let value = serde_json::json!({"name": "test_tool", "arguments": {}}); + let result = parse_tool_call_value(&value); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "test_tool"); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_empty_array() { + // Recovery: Empty tool_calls array should return empty vec + let value = serde_json::json!({"tool_calls": []}); + let result = parse_tool_calls_from_json_value(&value); + assert!(result.is_empty()); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_missing_tool_calls() { + // Recovery: Missing tool_calls field should fall through + let value = serde_json::json!({"name": "test", "arguments": {}}); + let result = parse_tool_calls_from_json_value(&value); + assert_eq!(result.len(), 1); + } + + #[test] + fn parse_tool_calls_from_json_value_handles_top_level_array() { + // Recovery: Top-level array of tool calls + let value = serde_json::json!([ + {"name": "tool_a", "arguments": {}}, + {"name": "tool_b", "arguments": {}} + ]); + let result = parse_tool_calls_from_json_value(&value); + assert_eq!(result.len(), 2); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index bd520a8..5256633 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,8 +2,8 @@ pub mod schema; pub use schema::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DelegateAgentConfig, - DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, - IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, - ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, - WebhookConfig, + DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, HttpRequestConfig, + IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, + ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, + TelegramConfig, TunnelConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index da00e7c..9d436d0 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -62,6 +62,9 @@ pub struct Config { #[serde(default)] pub browser: BrowserConfig, + #[serde(default)] + pub http_request: HttpRequestConfig, + #[serde(default)] pub identity: IdentityConfig, @@ -272,6 +275,32 @@ pub struct BrowserConfig { pub session_name: Option, } +// ── HTTP request tool ─────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct HttpRequestConfig { + /// Enable `http_request` tool for API interactions + #[serde(default)] + pub enabled: bool, + /// Allowed domains for HTTP requests (exact or subdomain match) + #[serde(default)] + pub allowed_domains: Vec, + /// Maximum response size in bytes (default: 1MB) + #[serde(default = "default_http_max_response_size")] + pub max_response_size: usize, + /// Request timeout in seconds (default: 30) + #[serde(default = "default_http_timeout_secs")] + pub timeout_secs: u64, +} + +fn default_http_max_response_size() -> usize { + 1_000_000 // 1MB +} + +fn default_http_timeout_secs() -> u64 { + 30 +} + // ── Memory ─────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -906,6 +935,7 @@ impl Default for Config { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), agents: HashMap::new(), } @@ -1257,6 +1287,7 @@ mod tests { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), agents: HashMap::new(), }; @@ -1329,6 +1360,7 @@ default_temperature = 0.7 composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), agents: HashMap::new(), }; diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 2baae7d..11b7279 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -105,6 +105,7 @@ pub fn run_wizard() -> Result { composio: composio_config, secrets: secrets_config, browser: BrowserConfig::default(), + http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), agents: std::collections::HashMap::new(), }; @@ -297,6 +298,7 @@ pub fn run_quick_setup( composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), agents: std::collections::HashMap::new(), }; diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs new file mode 100644 index 0000000..4ec9b01 --- /dev/null +++ b/src/tools/http_request.rs @@ -0,0 +1,605 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::sync::Arc; +use std::time::Duration; + +/// HTTP request tool for API interactions. +/// Supports GET, POST, PUT, DELETE methods with configurable security. +pub struct HttpRequestTool { + security: Arc, + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, +} + +impl HttpRequestTool { + pub fn new( + security: Arc, + allowed_domains: Vec, + max_response_size: usize, + timeout_secs: u64, + ) -> Self { + Self { + security, + allowed_domains: normalize_allowed_domains(allowed_domains), + max_response_size, + timeout_secs, + } + } + + fn validate_url(&self, raw_url: &str) -> anyhow::Result { + let url = raw_url.trim(); + + if url.is_empty() { + anyhow::bail!("URL cannot be empty"); + } + + if url.chars().any(char::is_whitespace) { + anyhow::bail!("URL cannot contain whitespace"); + } + + if !url.starts_with("http://") && !url.starts_with("https://") { + anyhow::bail!("Only http:// and https:// URLs are allowed"); + } + + if self.allowed_domains.is_empty() { + anyhow::bail!( + "HTTP request tool is enabled but no allowed_domains are configured. Add [http_request].allowed_domains in config.toml" + ); + } + + let host = extract_host(url)?; + + if is_private_or_local_host(&host) { + anyhow::bail!("Blocked local/private host: {host}"); + } + + if !host_matches_allowlist(&host, &self.allowed_domains) { + anyhow::bail!("Host '{host}' is not in http_request.allowed_domains"); + } + + Ok(url.to_string()) + } + + fn validate_method(&self, method: &str) -> anyhow::Result { + match method.to_uppercase().as_str() { + "GET" => Ok(reqwest::Method::GET), + "POST" => Ok(reqwest::Method::POST), + "PUT" => Ok(reqwest::Method::PUT), + "DELETE" => Ok(reqwest::Method::DELETE), + "PATCH" => Ok(reqwest::Method::PATCH), + "HEAD" => Ok(reqwest::Method::HEAD), + "OPTIONS" => Ok(reqwest::Method::OPTIONS), + _ => anyhow::bail!("Unsupported HTTP method: {method}. Supported: GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS"), + } + } + + fn sanitize_headers(&self, headers: &serde_json::Value) -> Vec<(String, String)> { + let mut result = Vec::new(); + if let Some(obj) = headers.as_object() { + for (key, value) in obj { + if let Some(str_val) = value.as_str() { + // Redact sensitive headers from logs (we don't log headers, but this is defense-in-depth) + let is_sensitive = key.to_lowercase().contains("authorization") + || key.to_lowercase().contains("api-key") + || key.to_lowercase().contains("apikey") + || key.to_lowercase().contains("token") + || key.to_lowercase().contains("secret"); + if is_sensitive { + result.push((key.clone(), "***REDACTED***".into())); + } else { + result.push((key.clone(), str_val.to_string())); + } + } + } + } + result + } + + async fn execute_request( + &self, + url: &str, + method: reqwest::Method, + headers: Vec<(String, String)>, + body: Option<&str>, + ) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(self.timeout_secs)) + .build()?; + + let mut request = client.request(method, url); + + for (key, value) in headers { + request = request.header(&key, &value); + } + + if let Some(body_str) = body { + request = request.body(body_str.to_string()); + } + + Ok(request.send().await?) + } + + fn truncate_response(&self, text: &str) -> String { + if text.len() > self.max_response_size { + let mut truncated = text.chars().take(self.max_response_size).collect::(); + truncated.push_str("\n\n... [Response truncated due to size limit] ..."); + truncated + } else { + text.to_string() + } + } +} + +#[async_trait] +impl Tool for HttpRequestTool { + fn name(&self) -> &str { + "http_request" + } + + fn description(&self) -> &str { + "Make HTTP requests to external APIs. Supports GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS methods. \ + Security constraints: allowlist-only domains, no local/private hosts, configurable timeout and response size limits." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "HTTP or HTTPS URL to request" + }, + "method": { + "type": "string", + "description": "HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)", + "default": "GET" + }, + "headers": { + "type": "object", + "description": "Optional HTTP headers as key-value pairs (e.g., {\"Authorization\": \"Bearer token\", \"Content-Type\": \"application/json\"})", + "default": {} + }, + "body": { + "type": "string", + "description": "Optional request body (for POST, PUT, PATCH requests)" + } + }, + "required": ["url"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let url = args + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'url' parameter"))?; + + let method_str = args.get("method").and_then(|v| v.as_str()).unwrap_or("GET"); + let headers_val = args.get("headers").cloned().unwrap_or(json!({})); + let body = args.get("body").and_then(|v| v.as_str()); + + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + let url = match self.validate_url(url) { + Ok(v) => v, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }) + } + }; + + let method = match self.validate_method(method_str) { + Ok(m) => m, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }) + } + }; + + let sanitized_headers = self.sanitize_headers(&headers_val); + + match self.execute_request(&url, method, sanitized_headers, body).await { + Ok(response) => { + let status = response.status(); + let status_code = status.as_u16(); + + // Get response headers (redact sensitive ones) + let response_headers = response.headers().iter(); + let headers_text = response_headers + .map(|(k, _)| { + let is_sensitive = k.as_str().to_lowercase().contains("set-cookie"); + if is_sensitive { + format!("{}: ***REDACTED***", k.as_str()) + } else { + format!("{}: {:?}", k.as_str(), k.as_str()) + } + }) + .collect::>() + .join(", "); + + // Get response body with size limit + let response_text = match response.text().await { + Ok(text) => self.truncate_response(&text), + Err(e) => format!("[Failed to read response body: {e}]"), + }; + + let output = format!( + "Status: {} {}\nResponse Headers: {}\n\nResponse Body:\n{}", + status_code, + status.canonical_reason().unwrap_or("Unknown"), + headers_text, + response_text + ); + + Ok(ToolResult { + success: status.is_success(), + output, + error: if status.is_client_error() || status.is_server_error() { + Some(format!("HTTP {}", status_code)) + } else { + None + }, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("HTTP request failed: {e}")), + }), + } + } +} + +// Helper functions similar to browser_open.rs + +fn normalize_allowed_domains(domains: Vec) -> Vec { + let mut normalized = domains + .into_iter() + .filter_map(|d| normalize_domain(&d)) + .collect::>(); + normalized.sort_unstable(); + normalized.dedup(); + normalized +} + +fn normalize_domain(raw: &str) -> Option { + let mut d = raw.trim().to_lowercase(); + if d.is_empty() { + return None; + } + + if let Some(stripped) = d.strip_prefix("https://") { + d = stripped.to_string(); + } else if let Some(stripped) = d.strip_prefix("http://") { + d = stripped.to_string(); + } + + if let Some((host, _)) = d.split_once('/') { + d = host.to_string(); + } + + d = d.trim_start_matches('.').trim_end_matches('.').to_string(); + + if let Some((host, _)) = d.split_once(':') { + d = host.to_string(); + } + + if d.is_empty() || d.chars().any(char::is_whitespace) { + return None; + } + + Some(d) +} + +fn extract_host(url: &str) -> anyhow::Result { + let rest = url + .strip_prefix("http://") + .or_else(|| url.strip_prefix("https://")) + .ok_or_else(|| anyhow::anyhow!("Only http:// and https:// URLs are allowed"))?; + + let authority = rest + .split(['/', '?', '#']) + .next() + .ok_or_else(|| anyhow::anyhow!("Invalid URL"))?; + + if authority.is_empty() { + anyhow::bail!("URL must include a host"); + } + + if authority.contains('@') { + anyhow::bail!("URL userinfo is not allowed"); + } + + if authority.starts_with('[') { + anyhow::bail!("IPv6 hosts are not supported in http_request"); + } + + let host = authority + .split(':') + .next() + .unwrap_or_default() + .trim() + .trim_end_matches('.') + .to_lowercase(); + + if host.is_empty() { + anyhow::bail!("URL must include a valid host"); + } + + Ok(host) +} + +fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool { + allowed_domains.iter().any(|domain| { + host == domain + || host + .strip_suffix(domain) + .is_some_and(|prefix| prefix.ends_with('.')) + }) +} + +fn is_private_or_local_host(host: &str) -> bool { + let has_local_tld = host + .rsplit('.') + .next() + .is_some_and(|label| label == "local"); + + if host == "localhost" || host.ends_with(".localhost") || has_local_tld || host == "::1" { + return true; + } + + if let Some([a, b, _, _]) = parse_ipv4(host) { + return a == 0 + || a == 10 + || a == 127 + || (a == 169 && b == 254) + || (a == 172 && (16..=31).contains(&b)) + || (a == 192 && b == 168) + || (a == 100 && (64..=127).contains(&b)); + } + + false +} + +fn parse_ipv4(host: &str) -> Option<[u8; 4]> { + let parts: Vec<&str> = host.split('.').collect(); + if parts.len() != 4 { + return None; + } + + let mut octets = [0_u8; 4]; + for (i, part) in parts.iter().enumerate() { + octets[i] = part.parse::().ok()?; + } + Some(octets) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_tool(allowed_domains: Vec<&str>) -> HttpRequestTool { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + ..SecurityPolicy::default() + }); + HttpRequestTool::new(security, allowed_domains.into_iter().map(String::from).collect(), 1_000_000, 30) + } + + #[test] + fn normalize_domain_strips_scheme_path_and_case() { + let got = normalize_domain(" HTTPS://Docs.Example.com/path ").unwrap(); + assert_eq!(got, "docs.example.com"); + } + + #[test] + fn normalize_allowed_domains_deduplicates() { + let got = normalize_allowed_domains(vec![ + "example.com".into(), + "EXAMPLE.COM".into(), + "https://example.com/".into(), + ]); + assert_eq!(got, vec!["example.com".to_string()]); + } + + #[test] + fn validate_accepts_exact_domain() { + let tool = test_tool(vec!["example.com"]); + let got = tool.validate_url("https://example.com/docs").unwrap(); + assert_eq!(got, "https://example.com/docs"); + } + + #[test] + fn validate_accepts_http() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_url("http://example.com").is_ok()); + } + + #[test] + fn validate_accepts_subdomain() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_url("https://api.example.com/v1").is_ok()); + } + + #[test] + fn validate_rejects_allowlist_miss() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://google.com") + .unwrap_err() + .to_string(); + assert!(err.contains("allowed_domains")); + } + + #[test] + fn validate_rejects_localhost() { + let tool = test_tool(vec!["localhost"]); + let err = tool + .validate_url("https://localhost:8080") + .unwrap_err() + .to_string(); + assert!(err.contains("local/private")); + } + + #[test] + fn validate_rejects_private_ipv4() { + let tool = test_tool(vec!["192.168.1.5"]); + let err = tool + .validate_url("https://192.168.1.5") + .unwrap_err() + .to_string(); + assert!(err.contains("local/private")); + } + + #[test] + fn validate_rejects_whitespace() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://example.com/hello world") + .unwrap_err() + .to_string(); + assert!(err.contains("whitespace")); + } + + #[test] + fn validate_rejects_userinfo() { + let tool = test_tool(vec!["example.com"]); + let err = tool + .validate_url("https://user@example.com") + .unwrap_err() + .to_string(); + assert!(err.contains("userinfo")); + } + + #[test] + fn validate_requires_allowlist() { + let security = Arc::new(SecurityPolicy::default()); + let tool = HttpRequestTool::new(security, vec![], 1_000_000, 30); + let err = tool + .validate_url("https://example.com") + .unwrap_err() + .to_string(); + assert!(err.contains("allowed_domains")); + } + + #[test] + fn validate_accepts_valid_methods() { + let tool = test_tool(vec!["example.com"]); + assert!(tool.validate_method("GET").is_ok()); + assert!(tool.validate_method("POST").is_ok()); + assert!(tool.validate_method("PUT").is_ok()); + assert!(tool.validate_method("DELETE").is_ok()); + assert!(tool.validate_method("PATCH").is_ok()); + assert!(tool.validate_method("HEAD").is_ok()); + assert!(tool.validate_method("OPTIONS").is_ok()); + } + + #[test] + fn validate_rejects_invalid_method() { + let tool = test_tool(vec!["example.com"]); + let err = tool.validate_method("INVALID").unwrap_err().to_string(); + assert!(err.contains("Unsupported HTTP method")); + } + + #[test] + fn parse_ipv4_valid() { + assert_eq!(parse_ipv4("1.2.3.4"), Some([1, 2, 3, 4])); + } + + #[test] + fn parse_ipv4_invalid() { + assert_eq!(parse_ipv4("1.2.3"), None); + assert_eq!(parse_ipv4("1.2.3.999"), None); + assert_eq!(parse_ipv4("not-an-ip"), None); + } + + #[tokio::test] + async fn execute_blocks_readonly_mode() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30); + let result = tool + .execute(json!({"url": "https://example.com"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn execute_blocks_when_rate_limited() { + let security = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = HttpRequestTool::new(security, vec!["example.com".into()], 1_000_000, 30); + let result = tool + .execute(json!({"url": "https://example.com"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("rate limit")); + } + + #[test] + fn truncate_response_within_limit() { + let tool = test_tool(vec!["example.com"]); + let text = "hello world"; + assert_eq!(tool.truncate_response(text), "hello world"); + } + + #[test] + fn truncate_response_over_limit() { + let tool = HttpRequestTool::new( + Arc::new(SecurityPolicy::default()), + vec!["example.com".into()], + 10, + 30, + ); + let text = "hello world this is long"; + let truncated = tool.truncate_response(text); + assert!(truncated.len() <= 10 + 60); // limit + message + assert!(truncated.contains("[Response truncated")); + } + + #[test] + fn sanitize_headers_redacts_sensitive() { + let tool = test_tool(vec!["example.com"]); + let headers = json!({ + "Authorization": "Bearer secret", + "Content-Type": "application/json", + "X-API-Key": "my-key" + }); + let sanitized = tool.sanitize_headers(&headers); + assert_eq!(sanitized.len(), 3); + assert!(sanitized.iter().any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); + assert!(sanitized.iter().any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); + assert!(sanitized.iter().any(|(k, v)| k == "Content-Type" && v == "application/json")); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index c2814c0..0f139d1 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -4,6 +4,7 @@ pub mod composio; pub mod delegate; pub mod file_read; pub mod file_write; +pub mod http_request; pub mod image_info; pub mod memory_forget; pub mod memory_recall; @@ -18,6 +19,7 @@ pub use composio::ComposioTool; pub use delegate::DelegateTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use http_request::HttpRequestTool; pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; @@ -58,6 +60,7 @@ pub fn all_tools( memory: Arc, composio_key: Option<&str>, browser_config: &crate::config::BrowserConfig, + http_config: &crate::config::HttpRequestConfig, agents: &HashMap, fallback_api_key: Option<&str>, ) -> Vec> { @@ -67,6 +70,7 @@ pub fn all_tools( memory, composio_key, browser_config, + http_config, agents, fallback_api_key, ) @@ -79,6 +83,7 @@ pub fn all_tools_with_runtime( memory: Arc, composio_key: Option<&str>, browser_config: &crate::config::BrowserConfig, + http_config: &crate::config::HttpRequestConfig, agents: &HashMap, fallback_api_key: Option<&str>, ) -> Vec> { @@ -105,6 +110,15 @@ pub fn all_tools_with_runtime( ))); } + if http_config.enabled { + tools.push(Box::new(HttpRequestTool::new( + security.clone(), + http_config.allowed_domains.clone(), + http_config.max_response_size, + http_config.timeout_secs, + ))); + } + // Vision tools are always available tools.push(Box::new(ScreenshotTool::new(security.clone()))); tools.push(Box::new(ImageInfoTool::new(security.clone()))); @@ -155,8 +169,9 @@ mod tests { allowed_domains: vec!["example.com".into()], session_name: None, }; + let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); + let tools = all_tools(&security, mem, None, &browser, &http, &HashMap::new(), None); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); } @@ -177,8 +192,9 @@ mod tests { allowed_domains: vec!["example.com".into()], session_name: None, }; + let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); + let tools = all_tools(&security, mem, None, &browser, &http, &HashMap::new(), None); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); } @@ -289,6 +305,7 @@ mod tests { Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); let browser = BrowserConfig::default(); + let http = crate::config::HttpRequestConfig::default(); let mut agents = HashMap::new(); agents.insert( @@ -303,7 +320,7 @@ mod tests { }, ); - let tools = all_tools(&security, mem, None, &browser, &agents, Some("sk-test")); + let tools = all_tools(&security, mem, None, &browser, &http, &agents, Some("sk-test")); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"delegate")); } @@ -320,8 +337,9 @@ mod tests { Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); let browser = BrowserConfig::default(); + let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); + let tools = all_tools(&security, mem, None, &browser, &http, &HashMap::new(), None); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"delegate")); }