Merge pull request #951 from zeroclaw-labs/fix/per-client-pairing-lockout

fix(security): change pairing lockout to per-client accounting
This commit is contained in:
Alex Gorevski 2026-02-19 11:26:46 -08:00 committed by GitHub
commit 77609777ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 73 additions and 28 deletions

View file

@ -626,7 +626,7 @@ impl TelegramChannel {
if let Some(code) = Self::extract_bind_code(text) { if let Some(code) = Self::extract_bind_code(text) {
if let Some(pairing) = self.pairing.as_ref() { if let Some(pairing) = self.pairing.as_ref() {
match pairing.try_pair(code).await { match pairing.try_pair(code, &chat_id).await {
Ok(Some(_token)) => { Ok(Some(_token)) => {
let bind_identity = normalized_sender_id.clone().or_else(|| { let bind_identity = normalized_sender_id.clone().or_else(|| {
if normalized_username.is_empty() || normalized_username == "unknown" { if normalized_username.is_empty() || normalized_username == "unknown" {

View file

@ -610,7 +610,7 @@ async fn handle_pair(
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or(""); .unwrap_or("");
match state.pairing.try_pair(code).await { match state.pairing.try_pair(code, &rate_key).await {
Ok(Some(token)) => { Ok(Some(token)) => {
tracing::info!("🔐 New client paired successfully"); tracing::info!("🔐 New client paired successfully");
if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await { if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await {
@ -1457,7 +1457,7 @@ mod tests {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap(); let code = guard.pairing_code().unwrap();
let token = guard.try_pair(&code).await.unwrap().unwrap(); let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap();
assert!(guard.is_authenticated(&token)); assert!(guard.is_authenticated(&token));
let shared_config = Arc::new(Mutex::new(config)); let shared_config = Arc::new(Mutex::new(config));

View file

@ -10,7 +10,7 @@
use parking_lot::Mutex; use parking_lot::Mutex;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -18,6 +18,8 @@ use std::time::Instant;
const MAX_PAIR_ATTEMPTS: u32 = 5; const MAX_PAIR_ATTEMPTS: u32 = 5;
/// Lockout duration after too many failed pairing attempts. /// Lockout duration after too many failed pairing attempts.
const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
/// Maximum number of tracked client entries to bound memory usage.
const MAX_TRACKED_CLIENTS: usize = 1024;
/// Manages pairing state for the gateway. /// Manages pairing state for the gateway.
/// ///
@ -33,8 +35,8 @@ pub struct PairingGuard {
pairing_code: Arc<Mutex<Option<String>>>, pairing_code: Arc<Mutex<Option<String>>>,
/// Set of SHA-256 hashed bearer tokens (persisted across restarts). /// Set of SHA-256 hashed bearer tokens (persisted across restarts).
paired_tokens: Arc<Mutex<HashSet<String>>>, paired_tokens: Arc<Mutex<HashSet<String>>>,
/// Brute-force protection: failed attempt counter + lockout time. /// Brute-force protection: per-client failed attempt counter + lockout time.
failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>, failed_attempts: Arc<Mutex<HashMap<String, (u32, Option<Instant>)>>>,
} }
impl PairingGuard { impl PairingGuard {
@ -66,7 +68,7 @@ impl PairingGuard {
require_pairing, require_pairing,
pairing_code: Arc::new(Mutex::new(code)), pairing_code: Arc::new(Mutex::new(code)),
paired_tokens: Arc::new(Mutex::new(tokens)), paired_tokens: Arc::new(Mutex::new(tokens)),
failed_attempts: Arc::new(Mutex::new((0, None))), failed_attempts: Arc::new(Mutex::new(HashMap::new())),
} }
} }
@ -80,11 +82,11 @@ impl PairingGuard {
self.require_pairing self.require_pairing
} }
fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> { fn try_pair_blocking(&self, code: &str, client_id: &str) -> Result<Option<String>, u64> {
// Check brute force lockout // Check brute force lockout for this specific client
{ {
let attempts = self.failed_attempts.lock(); let attempts = self.failed_attempts.lock();
if let (count, Some(locked_at)) = &*attempts { if let Some((count, Some(locked_at))) = attempts.get(client_id) {
if *count >= MAX_PAIR_ATTEMPTS { if *count >= MAX_PAIR_ATTEMPTS {
let elapsed = locked_at.elapsed().as_secs(); let elapsed = locked_at.elapsed().as_secs();
if elapsed < PAIR_LOCKOUT_SECS { if elapsed < PAIR_LOCKOUT_SECS {
@ -98,10 +100,10 @@ impl PairingGuard {
let mut pairing_code = self.pairing_code.lock(); let mut pairing_code = self.pairing_code.lock();
if let Some(ref expected) = *pairing_code { if let Some(ref expected) = *pairing_code {
if constant_time_eq(code.trim(), expected.trim()) { if constant_time_eq(code.trim(), expected.trim()) {
// Reset failed attempts on success // Reset failed attempts for this client on success
{ {
let mut attempts = self.failed_attempts.lock(); let mut attempts = self.failed_attempts.lock();
*attempts = (0, None); attempts.remove(client_id);
} }
let token = generate_token(); let token = generate_token();
let mut tokens = self.paired_tokens.lock(); let mut tokens = self.paired_tokens.lock();
@ -115,12 +117,29 @@ impl PairingGuard {
} }
} }
// Increment failed attempts // Increment failed attempts for this client
{ {
let mut attempts = self.failed_attempts.lock(); let mut attempts = self.failed_attempts.lock();
attempts.0 += 1;
if attempts.0 >= MAX_PAIR_ATTEMPTS { // Evict expired entries when approaching the bound
attempts.1 = Some(Instant::now()); if attempts.len() >= MAX_TRACKED_CLIENTS {
attempts.retain(|_, (_, locked_at)| {
locked_at
.map(|t| t.elapsed().as_secs() < PAIR_LOCKOUT_SECS)
.unwrap_or(true)
});
}
let entry = attempts.entry(client_id.to_string()).or_insert((0, None));
// Reset if previous lockout has expired
if let Some(locked_at) = entry.1 {
if locked_at.elapsed().as_secs() >= PAIR_LOCKOUT_SECS {
*entry = (0, None);
}
}
entry.0 += 1;
if entry.0 >= MAX_PAIR_ATTEMPTS {
entry.1 = Some(Instant::now());
} }
} }
@ -129,11 +148,13 @@ impl PairingGuard {
/// Attempt to pair with the given code. Returns a bearer token on success. /// Attempt to pair with the given code. Returns a bearer token on success.
/// Returns `Err(lockout_seconds)` if locked out due to brute force. /// Returns `Err(lockout_seconds)` if locked out due to brute force.
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> { /// `client_id` identifies the client for per-client lockout accounting.
pub async fn try_pair(&self, code: &str, client_id: &str) -> Result<Option<String>, u64> {
let this = self.clone(); let this = self.clone();
let code = code.to_string(); let code = code.to_string();
let client_id = client_id.to_string();
// TODO: make this function the main one without spawning a task // TODO: make this function the main one without spawning a task
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code)); let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code, &client_id));
handle handle
.await .await
@ -273,7 +294,7 @@ mod tests {
async fn try_pair_correct_code() { async fn try_pair_correct_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).await.unwrap(); let token = guard.try_pair(&code, "test_client").await.unwrap();
assert!(token.is_some()); assert!(token.is_some());
assert!(token.unwrap().starts_with("zc_")); assert!(token.unwrap().starts_with("zc_"));
assert!(guard.is_paired()); assert!(guard.is_paired());
@ -282,7 +303,7 @@ mod tests {
#[test] #[test]
async fn try_pair_wrong_code() { async fn try_pair_wrong_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let result = guard.try_pair("000000").await.unwrap(); let result = guard.try_pair("000000", "test_client").await.unwrap();
// Might succeed if code happens to be 000000, but extremely unlikely // Might succeed if code happens to be 000000, but extremely unlikely
// Just check it returns Ok(None) normally // Just check it returns Ok(None) normally
let _ = result; let _ = result;
@ -291,7 +312,7 @@ mod tests {
#[test] #[test]
async fn try_pair_empty_code() { async fn try_pair_empty_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
assert!(guard.try_pair("").await.unwrap().is_none()); assert!(guard.try_pair("", "test_client").await.unwrap().is_none());
} }
#[test] #[test]
@ -339,7 +360,7 @@ mod tests {
async fn pair_then_authenticate() { async fn pair_then_authenticate() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).await.unwrap().unwrap(); let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap();
assert!(guard.is_authenticated(&token)); assert!(guard.is_authenticated(&token));
assert!(!guard.is_authenticated("wrong")); assert!(!guard.is_authenticated("wrong"));
} }
@ -450,13 +471,14 @@ mod tests {
#[test] #[test]
async fn brute_force_lockout_after_max_attempts() { async fn brute_force_lockout_after_max_attempts() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let client = "attacker_client";
// Exhaust all attempts with wrong codes // Exhaust all attempts with wrong codes
for i in 0..MAX_PAIR_ATTEMPTS { for i in 0..MAX_PAIR_ATTEMPTS {
let result = guard.try_pair(&format!("wrong_{i}")).await; let result = guard.try_pair(&format!("wrong_{i}"), client).await;
assert!(result.is_ok(), "Attempt {i} should not be locked out yet"); assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
} }
// Next attempt should be locked out // Next attempt should be locked out
let result = guard.try_pair("another_wrong").await; let result = guard.try_pair("another_wrong", client).await;
assert!( assert!(
result.is_err(), result.is_err(),
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts" "Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
@ -473,26 +495,49 @@ mod tests {
async fn correct_code_resets_failed_attempts() { async fn correct_code_resets_failed_attempts() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
let client = "test_client";
// Fail a few times // Fail a few times
for _ in 0..3 { for _ in 0..3 {
let _ = guard.try_pair("wrong").await; let _ = guard.try_pair("wrong", client).await;
} }
// Correct code should still work (under MAX_PAIR_ATTEMPTS) // Correct code should still work (under MAX_PAIR_ATTEMPTS)
let result = guard.try_pair(&code).await.unwrap(); let result = guard.try_pair(&code, client).await.unwrap();
assert!(result.is_some(), "Correct code should work before lockout"); assert!(result.is_some(), "Correct code should work before lockout");
} }
#[test] #[test]
async fn lockout_returns_remaining_seconds() { async fn lockout_returns_remaining_seconds() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let client = "test_client";
for _ in 0..MAX_PAIR_ATTEMPTS { for _ in 0..MAX_PAIR_ATTEMPTS {
let _ = guard.try_pair("wrong").await; let _ = guard.try_pair("wrong", client).await;
} }
let err = guard.try_pair("wrong").await.unwrap_err(); let err = guard.try_pair("wrong", client).await.unwrap_err();
// Should be close to PAIR_LOCKOUT_SECS (within a second) // Should be close to PAIR_LOCKOUT_SECS (within a second)
assert!( assert!(
err >= PAIR_LOCKOUT_SECS - 1, err >= PAIR_LOCKOUT_SECS - 1,
"Remaining lockout should be ~{PAIR_LOCKOUT_SECS}s, got {err}s" "Remaining lockout should be ~{PAIR_LOCKOUT_SECS}s, got {err}s"
); );
} }
#[test]
async fn lockout_is_per_client() {
let guard = PairingGuard::new(true, &[]);
let attacker = "attacker_ip";
let legitimate = "legitimate_ip";
// Attacker exhausts attempts
for i in 0..MAX_PAIR_ATTEMPTS {
let _ = guard.try_pair(&format!("wrong_{i}"), attacker).await;
}
// Attacker is locked out
assert!(guard.try_pair("wrong", attacker).await.is_err());
// Legitimate client is NOT locked out
let result = guard.try_pair("wrong", legitimate).await;
assert!(
result.is_ok(),
"Legitimate client should not be locked out by attacker"
);
}
} }