diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index baf66fc..f9f5b6e 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -80,11 +80,14 @@ async fn gateway_agent_reply(state: &AppState, message: &str) -> Result Ok(normalize_gateway_reply(reply)) } +/// How often the rate limiter sweeps stale IP entries from its map. +const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, window: Duration, - requests: Mutex>>, + requests: Mutex<(HashMap>, Instant)>, } impl SlidingWindowRateLimiter { @@ -92,7 +95,7 @@ impl SlidingWindowRateLimiter { Self { limit_per_window, window, - requests: Mutex::new(HashMap::new()), + requests: Mutex::new((HashMap::new(), Instant::now())), } } @@ -104,10 +107,20 @@ impl SlidingWindowRateLimiter { let now = Instant::now(); let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now); - let mut requests = self + let mut guard = self .requests .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); + let (requests, last_sweep) = &mut *guard; + + // Periodic sweep: remove IPs with no recent requests + if last_sweep.elapsed() >= Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS) { + requests.retain(|_, timestamps| { + timestamps.retain(|t| *t > cutoff); + !timestamps.is_empty() + }); + *last_sweep = now; + } let entry = requests.entry(key.to_owned()).or_default(); entry.retain(|instant| *instant > cutoff); @@ -813,6 +826,55 @@ mod tests { assert!(!limiter.allow_pair("127.0.0.1")); } + #[test] + fn rate_limiter_sweep_removes_stale_entries() { + let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60)); + // Add entries for multiple IPs + assert!(limiter.allow("ip-1")); + assert!(limiter.allow("ip-2")); + assert!(limiter.allow("ip-3")); + + { + let guard = limiter + .requests + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(guard.0.len(), 3); + } + + // Force a sweep by backdating last_sweep + { + let mut guard = limiter + .requests + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1); + // Clear timestamps for ip-2 and ip-3 to simulate stale entries + guard.0.get_mut("ip-2").unwrap().clear(); + guard.0.get_mut("ip-3").unwrap().clear(); + } + + // Next allow() call should trigger sweep and remove stale entries + assert!(limiter.allow("ip-1")); + + { + let guard = limiter + .requests + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(guard.0.len(), 1, "Stale entries should have been swept"); + assert!(guard.0.contains_key("ip-1")); + } + } + + #[test] + fn rate_limiter_zero_limit_always_allows() { + let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60)); + for _ in 0..100 { + assert!(limiter.allow("any-key")); + } + } + #[test] fn idempotency_store_rejects_duplicate_key() { let store = IdempotencyStore::new(Duration::from_secs(30));