chore: Remove more blocking io calls
This commit is contained in:
parent
1aec9ad9c0
commit
f1ca73d3d2
14 changed files with 427 additions and 357 deletions
|
|
@ -11,6 +11,7 @@
|
|||
use parking_lot::Mutex;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Maximum failed pairing attempts before lockout.
|
||||
|
|
@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
|
|||
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
|
||||
/// in config files. When a new token is generated, the plaintext is returned
|
||||
/// to the client once, and only the hash is retained.
|
||||
#[derive(Debug)]
|
||||
// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairingGuard {
|
||||
/// Whether pairing is required at all.
|
||||
require_pairing: bool,
|
||||
/// One-time pairing code (generated on startup, consumed on first pair).
|
||||
pairing_code: Mutex<Option<String>>,
|
||||
pairing_code: Arc<Mutex<Option<String>>>,
|
||||
/// Set of SHA-256 hashed bearer tokens (persisted across restarts).
|
||||
paired_tokens: Mutex<HashSet<String>>,
|
||||
paired_tokens: Arc<Mutex<HashSet<String>>>,
|
||||
/// Brute-force protection: failed attempt counter + lockout time.
|
||||
failed_attempts: Mutex<(u32, Option<Instant>)>,
|
||||
failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>,
|
||||
}
|
||||
|
||||
impl PairingGuard {
|
||||
|
|
@ -62,9 +64,9 @@ impl PairingGuard {
|
|||
};
|
||||
Self {
|
||||
require_pairing,
|
||||
pairing_code: Mutex::new(code),
|
||||
paired_tokens: Mutex::new(tokens),
|
||||
failed_attempts: Mutex::new((0, None)),
|
||||
pairing_code: Arc::new(Mutex::new(code)),
|
||||
paired_tokens: Arc::new(Mutex::new(tokens)),
|
||||
failed_attempts: Arc::new(Mutex::new((0, None))),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -78,9 +80,7 @@ impl PairingGuard {
|
|||
self.require_pairing
|
||||
}
|
||||
|
||||
/// Attempt to pair with the given code. Returns a bearer token on success.
|
||||
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
||||
pub fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
// Check brute force lockout
|
||||
{
|
||||
let attempts = self.failed_attempts.lock();
|
||||
|
|
@ -127,6 +127,19 @@ impl PairingGuard {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
/// Attempt to pair with the given code. Returns a bearer token on success.
|
||||
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
||||
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
let this = self.clone();
|
||||
let code = code.to_string();
|
||||
// TODO: make this function the main one without spawning a task
|
||||
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code));
|
||||
|
||||
handle
|
||||
.await
|
||||
.expect("failed to spawn blocking task this should not happen")
|
||||
}
|
||||
|
||||
/// Check if a bearer token is valid (compares against stored hashes).
|
||||
pub fn is_authenticated(&self, token: &str) -> bool {
|
||||
if !self.require_pairing {
|
||||
|
|
@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::test;
|
||||
|
||||
// ── PairingGuard ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn new_guard_generates_code_when_no_tokens() {
|
||||
async fn new_guard_generates_code_when_no_tokens() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
assert!(guard.pairing_code().is_some());
|
||||
assert!(!guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_guard_no_code_when_tokens_exist() {
|
||||
async fn new_guard_no_code_when_tokens_exist() {
|
||||
let guard = PairingGuard::new(true, &["zc_existing".into()]);
|
||||
assert!(guard.pairing_code().is_none());
|
||||
assert!(guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_guard_no_code_when_pairing_disabled() {
|
||||
async fn new_guard_no_code_when_pairing_disabled() {
|
||||
let guard = PairingGuard::new(false, &[]);
|
||||
assert!(guard.pairing_code().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_correct_code() {
|
||||
async fn try_pair_correct_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
let token = guard.try_pair(&code).unwrap();
|
||||
let token = guard.try_pair(&code).await.unwrap();
|
||||
assert!(token.is_some());
|
||||
assert!(token.unwrap().starts_with("zc_"));
|
||||
assert!(guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_wrong_code() {
|
||||
async fn try_pair_wrong_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let result = guard.try_pair("000000").unwrap();
|
||||
let result = guard.try_pair("000000").await.unwrap();
|
||||
// Might succeed if code happens to be 000000, but extremely unlikely
|
||||
// Just check it returns Ok(None) normally
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_empty_code() {
|
||||
async fn try_pair_empty_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
assert!(guard.try_pair("").unwrap().is_none());
|
||||
assert!(guard.try_pair("").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_valid_token() {
|
||||
async fn is_authenticated_with_valid_token() {
|
||||
// Pass plaintext token — PairingGuard hashes it on load
|
||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||
assert!(guard.is_authenticated("zc_valid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_prehashed_token() {
|
||||
async fn is_authenticated_with_prehashed_token() {
|
||||
// Pass an already-hashed token (64 hex chars)
|
||||
let hashed = hash_token("zc_valid");
|
||||
let guard = PairingGuard::new(true, &[hashed]);
|
||||
|
|
@ -296,20 +310,20 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_invalid_token() {
|
||||
async fn is_authenticated_with_invalid_token() {
|
||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||
assert!(!guard.is_authenticated("zc_invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_when_pairing_disabled() {
|
||||
async fn is_authenticated_when_pairing_disabled() {
|
||||
let guard = PairingGuard::new(false, &[]);
|
||||
assert!(guard.is_authenticated("anything"));
|
||||
assert!(guard.is_authenticated(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_returns_hashes() {
|
||||
async fn tokens_returns_hashes() {
|
||||
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
|
||||
let tokens = guard.tokens();
|
||||
assert_eq!(tokens.len(), 2);
|
||||
|
|
@ -322,10 +336,10 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn pair_then_authenticate() {
|
||||
async fn pair_then_authenticate() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
let token = guard.try_pair(&code).unwrap().unwrap();
|
||||
let token = guard.try_pair(&code).await.unwrap().unwrap();
|
||||
assert!(guard.is_authenticated(&token));
|
||||
assert!(!guard.is_authenticated("wrong"));
|
||||
}
|
||||
|
|
@ -333,24 +347,24 @@ mod tests {
|
|||
// ── Token hashing ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hash_token_produces_64_hex_chars() {
|
||||
async fn hash_token_produces_64_hex_chars() {
|
||||
let hash = hash_token("zc_test_token");
|
||||
assert_eq!(hash.len(), 64);
|
||||
assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hash_token_is_deterministic() {
|
||||
async fn hash_token_is_deterministic() {
|
||||
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hash_token_differs_for_different_inputs() {
|
||||
async fn hash_token_differs_for_different_inputs() {
|
||||
assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_token_hash_detects_hash_vs_plaintext() {
|
||||
async fn is_token_hash_detects_hash_vs_plaintext() {
|
||||
assert!(is_token_hash(&hash_token("zc_test")));
|
||||
assert!(!is_token_hash("zc_test_token"));
|
||||
assert!(!is_token_hash("too_short"));
|
||||
|
|
@ -360,7 +374,7 @@ mod tests {
|
|||
// ── is_public_bind ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn localhost_variants_not_public() {
|
||||
async fn localhost_variants_not_public() {
|
||||
assert!(!is_public_bind("127.0.0.1"));
|
||||
assert!(!is_public_bind("localhost"));
|
||||
assert!(!is_public_bind("::1"));
|
||||
|
|
@ -368,12 +382,12 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn zero_zero_is_public() {
|
||||
async fn zero_zero_is_public() {
|
||||
assert!(is_public_bind("0.0.0.0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn real_ip_is_public() {
|
||||
async fn real_ip_is_public() {
|
||||
assert!(is_public_bind("192.168.1.100"));
|
||||
assert!(is_public_bind("10.0.0.1"));
|
||||
}
|
||||
|
|
@ -381,13 +395,13 @@ mod tests {
|
|||
// ── constant_time_eq ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn constant_time_eq_same() {
|
||||
async fn constant_time_eq_same() {
|
||||
assert!(constant_time_eq("abc", "abc"));
|
||||
assert!(constant_time_eq("", ""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn constant_time_eq_different() {
|
||||
async fn constant_time_eq_different() {
|
||||
assert!(!constant_time_eq("abc", "abd"));
|
||||
assert!(!constant_time_eq("abc", "ab"));
|
||||
assert!(!constant_time_eq("a", ""));
|
||||
|
|
@ -396,14 +410,14 @@ mod tests {
|
|||
// ── generate helpers ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn generate_code_is_6_digits() {
|
||||
async fn generate_code_is_6_digits() {
|
||||
let code = generate_code();
|
||||
assert_eq!(code.len(), 6);
|
||||
assert!(code.chars().all(|c| c.is_ascii_digit()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_code_is_not_deterministic() {
|
||||
async fn generate_code_is_not_deterministic() {
|
||||
// Two codes should differ with overwhelming probability. We try
|
||||
// multiple pairs so a single 1-in-10^6 collision doesn't cause
|
||||
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
|
||||
|
|
@ -416,7 +430,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn generate_token_has_prefix_and_hex_payload() {
|
||||
async fn generate_token_has_prefix_and_hex_payload() {
|
||||
let token = generate_token();
|
||||
let payload = token
|
||||
.strip_prefix("zc_")
|
||||
|
|
@ -434,15 +448,15 @@ mod tests {
|
|||
// ── Brute force protection ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn brute_force_lockout_after_max_attempts() {
|
||||
async fn brute_force_lockout_after_max_attempts() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
// Exhaust all attempts with wrong codes
|
||||
for i in 0..MAX_PAIR_ATTEMPTS {
|
||||
let result = guard.try_pair(&format!("wrong_{i}"));
|
||||
let result = guard.try_pair(&format!("wrong_{i}")).await;
|
||||
assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
|
||||
}
|
||||
// Next attempt should be locked out
|
||||
let result = guard.try_pair("another_wrong");
|
||||
let result = guard.try_pair("another_wrong").await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
|
||||
|
|
@ -456,25 +470,25 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn correct_code_resets_failed_attempts() {
|
||||
async fn correct_code_resets_failed_attempts() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
// Fail a few times
|
||||
for _ in 0..3 {
|
||||
let _ = guard.try_pair("wrong");
|
||||
let _ = guard.try_pair("wrong").await;
|
||||
}
|
||||
// Correct code should still work (under MAX_PAIR_ATTEMPTS)
|
||||
let result = guard.try_pair(&code).unwrap();
|
||||
let result = guard.try_pair(&code).await.unwrap();
|
||||
assert!(result.is_some(), "Correct code should work before lockout");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lockout_returns_remaining_seconds() {
|
||||
async fn lockout_returns_remaining_seconds() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
for _ in 0..MAX_PAIR_ATTEMPTS {
|
||||
let _ = guard.try_pair("wrong");
|
||||
let _ = guard.try_pair("wrong").await;
|
||||
}
|
||||
let err = guard.try_pair("wrong").unwrap_err();
|
||||
let err = guard.try_pair("wrong").await.unwrap_err();
|
||||
// Should be close to PAIR_LOCKOUT_SECS (within a second)
|
||||
assert!(
|
||||
err >= PAIR_LOCKOUT_SECS - 1,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue