fix(security): parse shell separators only when unquoted

This commit is contained in:
Chummy 2026-02-19 16:15:18 +08:00
parent a0098de28c
commit 67466254f0

View file

@ -167,45 +167,221 @@ fn skip_env_assignments(s: &str) -> &str {
}
}
/// Detect a single `&` operator (background/chain). `&&` is allowed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum QuoteState {
None,
Single,
Double,
}
/// Split a shell command into sub-commands by unquoted separators.
///
/// Separators:
/// - `;` and newline
/// - `|`
/// - `&&`, `||`
///
/// Characters inside single or double quotes are treated as literals, so
/// `sqlite3 db "SELECT 1; SELECT 2;"` remains a single segment.
fn split_unquoted_segments(command: &str) -> Vec<String> {
let mut segments = Vec::new();
let mut current = String::new();
let mut quote = QuoteState::None;
let mut escaped = false;
let mut chars = command.chars().peekable();
let push_segment = |segments: &mut Vec<String>, current: &mut String| {
let trimmed = current.trim();
if !trimmed.is_empty() {
segments.push(trimmed.to_string());
}
current.clear();
};
while let Some(ch) = chars.next() {
match quote {
QuoteState::Single => {
if ch == '\'' {
quote = QuoteState::None;
}
current.push(ch);
}
QuoteState::Double => {
if escaped {
escaped = false;
current.push(ch);
continue;
}
if ch == '\\' {
escaped = true;
current.push(ch);
continue;
}
if ch == '"' {
quote = QuoteState::None;
}
current.push(ch);
}
QuoteState::None => {
if escaped {
escaped = false;
current.push(ch);
continue;
}
if ch == '\\' {
escaped = true;
current.push(ch);
continue;
}
match ch {
'\'' => {
quote = QuoteState::Single;
current.push(ch);
}
'"' => {
quote = QuoteState::Double;
current.push(ch);
}
';' | '\n' => push_segment(&mut segments, &mut current),
'|' => {
if chars.next_if_eq(&'|').is_some() {
// Consume full `||`; both characters are separators.
}
push_segment(&mut segments, &mut current);
}
'&' => {
if chars.next_if_eq(&'&').is_some() {
// `&&` is a separator; single `&` is handled separately.
push_segment(&mut segments, &mut current);
} else {
current.push(ch);
}
}
_ => current.push(ch),
}
}
}
}
let trimmed = current.trim();
if !trimmed.is_empty() {
segments.push(trimmed.to_string());
}
segments
}
/// Detect a single unquoted `&` operator (background/chain). `&&` is allowed.
///
/// We treat any standalone `&` as unsafe in policy validation because it can
/// chain hidden sub-commands and escape foreground timeout expectations.
fn contains_single_ampersand(s: &str) -> bool {
let bytes = s.as_bytes();
for (i, b) in bytes.iter().enumerate() {
if *b != b'&' {
fn contains_unquoted_single_ampersand(command: &str) -> bool {
let mut quote = QuoteState::None;
let mut escaped = false;
let mut chars = command.chars().peekable();
while let Some(ch) = chars.next() {
match quote {
QuoteState::Single => {
if ch == '\'' {
quote = QuoteState::None;
}
}
QuoteState::Double => {
if escaped {
escaped = false;
continue;
}
let prev_is_amp = i > 0 && bytes[i - 1] == b'&';
let next_is_amp = i + 1 < bytes.len() && bytes[i + 1] == b'&';
if !prev_is_amp && !next_is_amp {
if ch == '\\' {
escaped = true;
continue;
}
if ch == '"' {
quote = QuoteState::None;
}
}
QuoteState::None => {
if escaped {
escaped = false;
continue;
}
if ch == '\\' {
escaped = true;
continue;
}
match ch {
'\'' => quote = QuoteState::Single,
'"' => quote = QuoteState::Double,
'&' => {
if chars.next_if_eq(&'&').is_none() {
return true;
}
}
_ => {}
}
}
}
}
false
}
/// Detect an unquoted character in a shell command.
fn contains_unquoted_char(command: &str, target: char) -> bool {
let mut quote = QuoteState::None;
let mut escaped = false;
for ch in command.chars() {
match quote {
QuoteState::Single => {
if ch == '\'' {
quote = QuoteState::None;
}
}
QuoteState::Double => {
if escaped {
escaped = false;
continue;
}
if ch == '\\' {
escaped = true;
continue;
}
if ch == '"' {
quote = QuoteState::None;
continue;
}
}
QuoteState::None => {
if escaped {
escaped = false;
continue;
}
if ch == '\\' {
escaped = true;
continue;
}
match ch {
'\'' => quote = QuoteState::Single,
'"' => quote = QuoteState::Double,
_ if ch == target => return true,
_ => {}
}
}
}
}
false
}
impl SecurityPolicy {
/// Classify command risk. Any high-risk segment marks the whole command high.
pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel {
let mut normalized = command.to_string();
for sep in ["&&", "||"] {
normalized = normalized.replace(sep, "\x00");
}
for sep in ['\n', ';', '|', '&'] {
normalized = normalized.replace(sep, "\x00");
}
let mut saw_medium = false;
for segment in normalized.split('\x00') {
let segment = segment.trim();
if segment.is_empty() {
continue;
}
let cmd_part = skip_env_assignments(segment);
for segment in split_unquoted_segments(command) {
let cmd_part = skip_env_assignments(&segment);
let mut words = cmd_part.split_whitespace();
let Some(base_raw) = words.next() else {
continue;
@ -369,8 +545,9 @@ impl SecurityPolicy {
return false;
}
// Block output redirections — they can write to arbitrary paths
if command.contains('>') {
// Block output redirections (`>`, `>>`) — they can write to arbitrary paths.
// Ignore quoted literals, e.g. `echo "a>b"`.
if contains_unquoted_char(command, '>') {
return false;
}
@ -385,26 +562,13 @@ impl SecurityPolicy {
// Block background command chaining (`&`), which can hide extra
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
if contains_single_ampersand(command) {
if contains_unquoted_single_ampersand(command) {
return false;
}
// Split on command separators and validate each sub-command.
// We collect segments by scanning for separator characters.
let mut normalized = command.to_string();
for sep in ["&&", "||"] {
normalized = normalized.replace(sep, "\x00");
}
for sep in ['\n', ';', '|'] {
normalized = normalized.replace(sep, "\x00");
}
for segment in normalized.split('\x00') {
let segment = segment.trim();
if segment.is_empty() {
continue;
}
// Split on unquoted command separators and validate each sub-command.
let segments = split_unquoted_segments(command);
for segment in &segments {
// Strip leading env var assignments (e.g. FOO=bar cmd)
let cmd_part = skip_env_assignments(segment);
@ -432,7 +596,7 @@ impl SecurityPolicy {
}
// At least one command must be present
let has_cmd = normalized.split('\x00').any(|s| {
let has_cmd = segments.iter().any(|s| {
let s = skip_env_assignments(s.trim());
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
});
@ -832,6 +996,19 @@ mod tests {
assert!(result.unwrap_err().contains("high-risk"));
}
#[test]
fn validate_command_full_mode_skips_medium_risk_approval_gate() {
let p = SecurityPolicy {
autonomy: AutonomyLevel::Full,
require_approval_for_medium_risk: true,
allowed_commands: vec!["touch".into()],
..SecurityPolicy::default()
};
let result = p.validate_command_execution("touch test.txt", false);
assert_eq!(result.unwrap(), CommandRiskLevel::Medium);
}
#[test]
fn validate_command_rejects_background_chain_bypass() {
let p = default_policy();
@ -1027,6 +1204,32 @@ mod tests {
assert!(!p.is_command_allowed("ls;rm -rf /"));
}
#[test]
fn quoted_semicolons_do_not_split_sqlite_command() {
let p = SecurityPolicy {
allowed_commands: vec!["sqlite3".into()],
..SecurityPolicy::default()
};
assert!(p.is_command_allowed(
"sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\""
));
assert_eq!(
p.command_risk_level(
"sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\""
),
CommandRiskLevel::Low
);
}
#[test]
fn unquoted_semicolon_after_quoted_sql_still_splits_commands() {
let p = SecurityPolicy {
allowed_commands: vec!["sqlite3".into()],
..SecurityPolicy::default()
};
assert!(!p.is_command_allowed("sqlite3 /tmp/test.db \"SELECT 1;\"; rm -rf /"));
}
#[test]
fn command_injection_backtick_blocked() {
let p = default_policy();
@ -1089,6 +1292,13 @@ mod tests {
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
}
#[test]
fn quoted_ampersand_and_redirect_literals_are_not_treated_as_operators() {
let p = default_policy();
assert!(p.is_command_allowed("echo \"A&B\""));
assert!(p.is_command_allowed("echo \"A>B\""));
}
#[test]
fn command_argument_injection_blocked() {
let p = default_policy();