diff --git a/src/security/policy.rs b/src/security/policy.rs index d2f0d29..debc7f6 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -167,45 +167,221 @@ fn skip_env_assignments(s: &str) -> &str { } } -/// Detect a single `&` operator (background/chain). `&&` is allowed. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum QuoteState { + None, + Single, + Double, +} + +/// Split a shell command into sub-commands by unquoted separators. +/// +/// Separators: +/// - `;` and newline +/// - `|` +/// - `&&`, `||` +/// +/// Characters inside single or double quotes are treated as literals, so +/// `sqlite3 db "SELECT 1; SELECT 2;"` remains a single segment. +fn split_unquoted_segments(command: &str) -> Vec { + let mut segments = Vec::new(); + let mut current = String::new(); + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + let push_segment = |segments: &mut Vec, current: &mut String| { + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + current.clear(); + }; + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::Double => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::None => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + + match ch { + '\'' => { + quote = QuoteState::Single; + current.push(ch); + } + '"' => { + quote = QuoteState::Double; + current.push(ch); + } + ';' | '\n' => push_segment(&mut segments, &mut current), + '|' => { + if chars.next_if_eq(&'|').is_some() { + // Consume full `||`; both characters are separators. + } + push_segment(&mut segments, &mut current); + } + '&' => { + if chars.next_if_eq(&'&').is_some() { + // `&&` is a separator; single `&` is handled separately. + push_segment(&mut segments, &mut current); + } else { + current.push(ch); + } + } + _ => current.push(ch), + } + } + } + } + + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + + segments +} + +/// Detect a single unquoted `&` operator (background/chain). `&&` is allowed. /// /// We treat any standalone `&` as unsafe in policy validation because it can /// chain hidden sub-commands and escape foreground timeout expectations. -fn contains_single_ampersand(s: &str) -> bool { - let bytes = s.as_bytes(); - for (i, b) in bytes.iter().enumerate() { - if *b != b'&' { - continue; - } - let prev_is_amp = i > 0 && bytes[i - 1] == b'&'; - let next_is_amp = i + 1 < bytes.len() && bytes[i + 1] == b'&'; - if !prev_is_amp && !next_is_amp { - return true; +fn contains_unquoted_single_ampersand(command: &str) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + '&' => { + if chars.next_if_eq(&'&').is_none() { + return true; + } + } + _ => {} + } + } } } + + false +} + +/// Detect an unquoted character in a shell command. +fn contains_unquoted_char(command: &str, target: char) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + + for ch in command.chars() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + continue; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + _ if ch == target => return true, + _ => {} + } + } + } + } + false } impl SecurityPolicy { /// Classify command risk. Any high-risk segment marks the whole command high. pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { - let mut normalized = command.to_string(); - for sep in ["&&", "||"] { - normalized = normalized.replace(sep, "\x00"); - } - for sep in ['\n', ';', '|', '&'] { - normalized = normalized.replace(sep, "\x00"); - } - let mut saw_medium = false; - for segment in normalized.split('\x00') { - let segment = segment.trim(); - if segment.is_empty() { - continue; - } - - let cmd_part = skip_env_assignments(segment); + for segment in split_unquoted_segments(command) { + let cmd_part = skip_env_assignments(&segment); let mut words = cmd_part.split_whitespace(); let Some(base_raw) = words.next() else { continue; @@ -369,8 +545,9 @@ impl SecurityPolicy { return false; } - // Block output redirections — they can write to arbitrary paths - if command.contains('>') { + // Block output redirections (`>`, `>>`) — they can write to arbitrary paths. + // Ignore quoted literals, e.g. `echo "a>b"`. + if contains_unquoted_char(command, '>') { return false; } @@ -385,26 +562,13 @@ impl SecurityPolicy { // Block background command chaining (`&`), which can hide extra // sub-commands and outlive timeout expectations. Keep `&&` allowed. - if contains_single_ampersand(command) { + if contains_unquoted_single_ampersand(command) { return false; } - // Split on command separators and validate each sub-command. - // We collect segments by scanning for separator characters. - let mut normalized = command.to_string(); - for sep in ["&&", "||"] { - normalized = normalized.replace(sep, "\x00"); - } - for sep in ['\n', ';', '|'] { - normalized = normalized.replace(sep, "\x00"); - } - - for segment in normalized.split('\x00') { - let segment = segment.trim(); - if segment.is_empty() { - continue; - } - + // Split on unquoted command separators and validate each sub-command. + let segments = split_unquoted_segments(command); + for segment in &segments { // Strip leading env var assignments (e.g. FOO=bar cmd) let cmd_part = skip_env_assignments(segment); @@ -432,7 +596,7 @@ impl SecurityPolicy { } // At least one command must be present - let has_cmd = normalized.split('\x00').any(|s| { + let has_cmd = segments.iter().any(|s| { let s = skip_env_assignments(s.trim()); s.split_whitespace().next().is_some_and(|w| !w.is_empty()) }); @@ -832,6 +996,19 @@ mod tests { assert!(result.unwrap_err().contains("high-risk")); } + #[test] + fn validate_command_full_mode_skips_medium_risk_approval_gate() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Full, + require_approval_for_medium_risk: true, + allowed_commands: vec!["touch".into()], + ..SecurityPolicy::default() + }; + + let result = p.validate_command_execution("touch test.txt", false); + assert_eq!(result.unwrap(), CommandRiskLevel::Medium); + } + #[test] fn validate_command_rejects_background_chain_bypass() { let p = default_policy(); @@ -1027,6 +1204,32 @@ mod tests { assert!(!p.is_command_allowed("ls;rm -rf /")); } + #[test] + fn quoted_semicolons_do_not_split_sqlite_command() { + let p = SecurityPolicy { + allowed_commands: vec!["sqlite3".into()], + ..SecurityPolicy::default() + }; + assert!(p.is_command_allowed( + "sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\"" + )); + assert_eq!( + p.command_risk_level( + "sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\"" + ), + CommandRiskLevel::Low + ); + } + + #[test] + fn unquoted_semicolon_after_quoted_sql_still_splits_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["sqlite3".into()], + ..SecurityPolicy::default() + }; + assert!(!p.is_command_allowed("sqlite3 /tmp/test.db \"SELECT 1;\"; rm -rf /")); + } + #[test] fn command_injection_backtick_blocked() { let p = default_policy(); @@ -1089,6 +1292,13 @@ mod tests { assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); } + #[test] + fn quoted_ampersand_and_redirect_literals_are_not_treated_as_operators() { + let p = default_policy(); + assert!(p.is_command_allowed("echo \"A&B\"")); + assert!(p.is_command_allowed("echo \"A>B\"")); + } + #[test] fn command_argument_injection_blocked() { let p = default_policy();