diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index fa7f130..7602690 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -629,7 +629,7 @@ impl TelegramChannel { if let Some(code) = Self::extract_bind_code(text) { if let Some(pairing) = self.pairing.as_ref() { - match pairing.try_pair(code).await { + match pairing.try_pair(code, &chat_id).await { Ok(Some(_token)) => { let bind_identity = normalized_sender_id.clone().or_else(|| { if normalized_username.is_empty() || normalized_username == "unknown" { diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2f56909..b42b4e1 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -610,7 +610,7 @@ async fn handle_pair( .and_then(|v| v.to_str().ok()) .unwrap_or(""); - match state.pairing.try_pair(code).await { + match state.pairing.try_pair(code, &rate_key).await { Ok(Some(token)) => { tracing::info!("🔐 New client paired successfully"); if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await { @@ -1457,7 +1457,7 @@ mod tests { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap(); - let token = guard.try_pair(&code).await.unwrap().unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); let shared_config = Arc::new(Mutex::new(config)); diff --git a/src/security/pairing.rs b/src/security/pairing.rs index c6a14e5..b772b38 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -10,7 +10,7 @@ use parking_lot::Mutex; use sha2::{Digest, Sha256}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Instant; @@ -18,6 +18,8 @@ use std::time::Instant; const MAX_PAIR_ATTEMPTS: u32 = 5; /// Lockout duration after too many failed pairing attempts. const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes +/// Maximum number of tracked client entries to bound memory usage. +const MAX_TRACKED_CLIENTS: usize = 1024; /// Manages pairing state for the gateway. /// @@ -33,8 +35,8 @@ pub struct PairingGuard { pairing_code: Arc>>, /// Set of SHA-256 hashed bearer tokens (persisted across restarts). paired_tokens: Arc>>, - /// Brute-force protection: failed attempt counter + lockout time. - failed_attempts: Arc)>>, + /// Brute-force protection: per-client failed attempt counter + lockout time. + failed_attempts: Arc)>>>, } impl PairingGuard { @@ -66,7 +68,7 @@ impl PairingGuard { require_pairing, pairing_code: Arc::new(Mutex::new(code)), paired_tokens: Arc::new(Mutex::new(tokens)), - failed_attempts: Arc::new(Mutex::new((0, None))), + failed_attempts: Arc::new(Mutex::new(HashMap::new())), } } @@ -80,11 +82,11 @@ impl PairingGuard { self.require_pairing } - fn try_pair_blocking(&self, code: &str) -> Result, u64> { - // Check brute force lockout + fn try_pair_blocking(&self, code: &str, client_id: &str) -> Result, u64> { + // Check brute force lockout for this specific client { let attempts = self.failed_attempts.lock(); - if let (count, Some(locked_at)) = &*attempts { + if let Some((count, Some(locked_at))) = attempts.get(client_id) { if *count >= MAX_PAIR_ATTEMPTS { let elapsed = locked_at.elapsed().as_secs(); if elapsed < PAIR_LOCKOUT_SECS { @@ -98,10 +100,10 @@ impl PairingGuard { let mut pairing_code = self.pairing_code.lock(); if let Some(ref expected) = *pairing_code { if constant_time_eq(code.trim(), expected.trim()) { - // Reset failed attempts on success + // Reset failed attempts for this client on success { let mut attempts = self.failed_attempts.lock(); - *attempts = (0, None); + attempts.remove(client_id); } let token = generate_token(); let mut tokens = self.paired_tokens.lock(); @@ -115,12 +117,29 @@ impl PairingGuard { } } - // Increment failed attempts + // Increment failed attempts for this client { let mut attempts = self.failed_attempts.lock(); - attempts.0 += 1; - if attempts.0 >= MAX_PAIR_ATTEMPTS { - attempts.1 = Some(Instant::now()); + + // Evict expired entries when approaching the bound + if attempts.len() >= MAX_TRACKED_CLIENTS { + attempts.retain(|_, (_, locked_at)| { + locked_at + .map(|t| t.elapsed().as_secs() < PAIR_LOCKOUT_SECS) + .unwrap_or(true) + }); + } + + let entry = attempts.entry(client_id.to_string()).or_insert((0, None)); + // Reset if previous lockout has expired + if let Some(locked_at) = entry.1 { + if locked_at.elapsed().as_secs() >= PAIR_LOCKOUT_SECS { + *entry = (0, None); + } + } + entry.0 += 1; + if entry.0 >= MAX_PAIR_ATTEMPTS { + entry.1 = Some(Instant::now()); } } @@ -129,11 +148,13 @@ impl PairingGuard { /// Attempt to pair with the given code. Returns a bearer token on success. /// Returns `Err(lockout_seconds)` if locked out due to brute force. - pub async fn try_pair(&self, code: &str) -> Result, u64> { + /// `client_id` identifies the client for per-client lockout accounting. + pub async fn try_pair(&self, code: &str, client_id: &str) -> Result, u64> { let this = self.clone(); let code = code.to_string(); + let client_id = client_id.to_string(); // TODO: make this function the main one without spawning a task - let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code)); + let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code, &client_id)); handle .await @@ -273,7 +294,7 @@ mod tests { async fn try_pair_correct_code() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).await.unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap(); assert!(token.is_some()); assert!(token.unwrap().starts_with("zc_")); assert!(guard.is_paired()); @@ -282,7 +303,7 @@ mod tests { #[test] async fn try_pair_wrong_code() { let guard = PairingGuard::new(true, &[]); - let result = guard.try_pair("000000").await.unwrap(); + let result = guard.try_pair("000000", "test_client").await.unwrap(); // Might succeed if code happens to be 000000, but extremely unlikely // Just check it returns Ok(None) normally let _ = result; @@ -291,7 +312,7 @@ mod tests { #[test] async fn try_pair_empty_code() { let guard = PairingGuard::new(true, &[]); - assert!(guard.try_pair("").await.unwrap().is_none()); + assert!(guard.try_pair("", "test_client").await.unwrap().is_none()); } #[test] @@ -339,7 +360,7 @@ mod tests { async fn pair_then_authenticate() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).await.unwrap().unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); assert!(!guard.is_authenticated("wrong")); } @@ -450,13 +471,14 @@ mod tests { #[test] async fn brute_force_lockout_after_max_attempts() { let guard = PairingGuard::new(true, &[]); + let client = "attacker_client"; // Exhaust all attempts with wrong codes for i in 0..MAX_PAIR_ATTEMPTS { - let result = guard.try_pair(&format!("wrong_{i}")).await; + let result = guard.try_pair(&format!("wrong_{i}"), client).await; assert!(result.is_ok(), "Attempt {i} should not be locked out yet"); } // Next attempt should be locked out - let result = guard.try_pair("another_wrong").await; + let result = guard.try_pair("another_wrong", client).await; assert!( result.is_err(), "Should be locked out after {MAX_PAIR_ATTEMPTS} attempts" @@ -473,26 +495,49 @@ mod tests { async fn correct_code_resets_failed_attempts() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); + let client = "test_client"; // Fail a few times for _ in 0..3 { - let _ = guard.try_pair("wrong").await; + let _ = guard.try_pair("wrong", client).await; } // Correct code should still work (under MAX_PAIR_ATTEMPTS) - let result = guard.try_pair(&code).await.unwrap(); + let result = guard.try_pair(&code, client).await.unwrap(); assert!(result.is_some(), "Correct code should work before lockout"); } #[test] async fn lockout_returns_remaining_seconds() { let guard = PairingGuard::new(true, &[]); + let client = "test_client"; for _ in 0..MAX_PAIR_ATTEMPTS { - let _ = guard.try_pair("wrong").await; + let _ = guard.try_pair("wrong", client).await; } - let err = guard.try_pair("wrong").await.unwrap_err(); + let err = guard.try_pair("wrong", client).await.unwrap_err(); // Should be close to PAIR_LOCKOUT_SECS (within a second) assert!( err >= PAIR_LOCKOUT_SECS - 1, "Remaining lockout should be ~{PAIR_LOCKOUT_SECS}s, got {err}s" ); } + + #[test] + async fn lockout_is_per_client() { + let guard = PairingGuard::new(true, &[]); + let attacker = "attacker_ip"; + let legitimate = "legitimate_ip"; + + // Attacker exhausts attempts + for i in 0..MAX_PAIR_ATTEMPTS { + let _ = guard.try_pair(&format!("wrong_{i}"), attacker).await; + } + // Attacker is locked out + assert!(guard.try_pair("wrong", attacker).await.is_err()); + + // Legitimate client is NOT locked out + let result = guard.try_pair("wrong", legitimate).await; + assert!( + result.is_ok(), + "Legitimate client should not be locked out by attacker" + ); + } }