From c50785671021b50ffbc6648c06638d76e0c1a471 Mon Sep 17 00:00:00 2001 From: fettpl <38704082+fettpl@users.noreply.github.com> Date: Tue, 17 Feb 2026 22:09:38 +0100 Subject: [PATCH] fix(gateway): harden client identity and bound key stores --- src/config/schema.rs | 36 ++++++ src/gateway/mod.rs | 279 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 267 insertions(+), 48 deletions(-) diff --git a/src/config/schema.rs b/src/config/schema.rs index 0acec6e..bf25866 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -479,9 +479,22 @@ pub struct GatewayConfig { #[serde(default = "default_webhook_rate_limit")] pub webhook_rate_limit_per_minute: u32, + /// Trust proxy-forwarded client IP headers (`X-Forwarded-For`, `X-Real-IP`). + /// Disabled by default; enable only behind a trusted reverse proxy. + #[serde(default)] + pub trust_forwarded_headers: bool, + + /// Maximum distinct client keys tracked by gateway rate limiter maps. + #[serde(default = "default_gateway_rate_limit_max_keys")] + pub rate_limit_max_keys: usize, + /// TTL for webhook idempotency keys. #[serde(default = "default_idempotency_ttl_secs")] pub idempotency_ttl_secs: u64, + + /// Maximum distinct idempotency keys retained in memory. + #[serde(default = "default_gateway_idempotency_max_keys")] + pub idempotency_max_keys: usize, } fn default_gateway_port() -> u16 { @@ -504,6 +517,14 @@ fn default_idempotency_ttl_secs() -> u64 { 300 } +fn default_gateway_rate_limit_max_keys() -> usize { + 10_000 +} + +fn default_gateway_idempotency_max_keys() -> usize { + 10_000 +} + fn default_true() -> bool { true } @@ -518,7 +539,10 @@ impl Default for GatewayConfig { paired_tokens: Vec::new(), pair_rate_limit_per_minute: default_pair_rate_limit(), webhook_rate_limit_per_minute: default_webhook_rate_limit(), + trust_forwarded_headers: false, + rate_limit_max_keys: default_gateway_rate_limit_max_keys(), idempotency_ttl_secs: default_idempotency_ttl_secs(), + idempotency_max_keys: default_gateway_idempotency_max_keys(), } } } @@ -2946,7 +2970,10 @@ channel_id = "C123" ); assert_eq!(g.pair_rate_limit_per_minute, 10); assert_eq!(g.webhook_rate_limit_per_minute, 60); + assert!(!g.trust_forwarded_headers); + assert_eq!(g.rate_limit_max_keys, 10_000); assert_eq!(g.idempotency_ttl_secs, 300); + assert_eq!(g.idempotency_max_keys, 10_000); } #[test] @@ -2974,7 +3001,10 @@ channel_id = "C123" paired_tokens: vec!["zc_test_token".into()], pair_rate_limit_per_minute: 12, webhook_rate_limit_per_minute: 80, + trust_forwarded_headers: true, + rate_limit_max_keys: 2048, idempotency_ttl_secs: 600, + idempotency_max_keys: 4096, }; let toml_str = toml::to_string(&g).unwrap(); let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap(); @@ -2983,7 +3013,10 @@ channel_id = "C123" assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]); assert_eq!(parsed.pair_rate_limit_per_minute, 12); assert_eq!(parsed.webhook_rate_limit_per_minute, 80); + assert!(parsed.trust_forwarded_headers); + assert_eq!(parsed.rate_limit_max_keys, 2048); assert_eq!(parsed.idempotency_ttl_secs, 600); + assert_eq!(parsed.idempotency_max_keys, 4096); } #[test] @@ -3622,6 +3655,9 @@ default_model = "legacy-model" assert!(g.require_pairing); assert!(!g.allow_public_bind); assert!(g.paired_tokens.is_empty()); + assert!(!g.trust_forwarded_headers); + assert_eq!(g.rate_limit_max_keys, 10_000); + assert_eq!(g.idempotency_max_keys, 10_000); } // ── Peripherals config ─────────────────────────────────────── diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 3eb795e..0db6447 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -19,7 +19,7 @@ use crate::util::truncate_with_ellipsis; use anyhow::{Context, Result}; use axum::{ body::Bytes, - extract::{Query, State}, + extract::{ConnectInfo, Query, State}, http::{header, HeaderMap, StatusCode}, response::{IntoResponse, Json}, routing::{get, post}, @@ -27,7 +27,7 @@ use axum::{ }; use parking_lot::Mutex; use std::collections::HashMap; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::{Duration, Instant}; use tower_http::limit::RequestBodyLimitLayer; @@ -40,6 +40,10 @@ pub const MAX_BODY_SIZE: usize = 65_536; pub const REQUEST_TIMEOUT_SECS: u64 = 30; /// Sliding window used by gateway rate limiting. pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; +/// Fallback max distinct client keys tracked in gateway rate limiter. +pub const RATE_LIMIT_MAX_KEYS_DEFAULT: usize = 10_000; +/// Fallback max distinct idempotency keys retained in gateway memory. +pub const IDEMPOTENCY_MAX_KEYS_DEFAULT: usize = 10_000; fn webhook_memory_key() -> String { format!("webhook_msg_{}", Uuid::new_v4()) @@ -63,18 +67,27 @@ const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes struct SlidingWindowRateLimiter { limit_per_window: u32, window: Duration, + max_keys: usize, requests: Mutex<(HashMap>, Instant)>, } impl SlidingWindowRateLimiter { - fn new(limit_per_window: u32, window: Duration) -> Self { + fn new(limit_per_window: u32, window: Duration, max_keys: usize) -> Self { Self { limit_per_window, window, + max_keys: max_keys.max(1), requests: Mutex::new((HashMap::new(), Instant::now())), } } + fn prune_stale(requests: &mut HashMap>, cutoff: Instant) { + requests.retain(|_, timestamps| { + timestamps.retain(|t| *t > cutoff); + !timestamps.is_empty() + }); + } + fn allow(&self, key: &str) -> bool { if self.limit_per_window == 0 { return true; @@ -86,15 +99,28 @@ impl SlidingWindowRateLimiter { let mut guard = self.requests.lock(); let (requests, last_sweep) = &mut *guard; - // Periodic sweep: remove IPs with no recent requests + // Periodic sweep: remove keys with no recent requests if last_sweep.elapsed() >= Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS) { - requests.retain(|_, timestamps| { - timestamps.retain(|t| *t > cutoff); - !timestamps.is_empty() - }); + Self::prune_stale(requests, cutoff); *last_sweep = now; } + if !requests.contains_key(key) && requests.len() >= self.max_keys { + // Opportunistic stale cleanup before eviction under cardinality pressure. + Self::prune_stale(requests, cutoff); + *last_sweep = now; + + if requests.len() >= self.max_keys { + let evict_key = requests + .iter() + .min_by_key(|(_, timestamps)| timestamps.last().copied().unwrap_or(cutoff)) + .map(|(k, _)| k.clone()); + if let Some(evict_key) = evict_key { + requests.remove(&evict_key); + } + } + } + let entry = requests.entry(key.to_owned()).or_default(); entry.retain(|instant| *instant > cutoff); @@ -114,11 +140,11 @@ pub struct GatewayRateLimiter { } impl GatewayRateLimiter { - fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self { + fn new(pair_per_minute: u32, webhook_per_minute: u32, max_keys: usize) -> Self { let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS); Self { - pair: SlidingWindowRateLimiter::new(pair_per_minute, window), - webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window), + pair: SlidingWindowRateLimiter::new(pair_per_minute, window, max_keys), + webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window, max_keys), } } @@ -134,13 +160,15 @@ impl GatewayRateLimiter { #[derive(Debug)] pub struct IdempotencyStore { ttl: Duration, + max_keys: usize, keys: Mutex>, } impl IdempotencyStore { - fn new(ttl: Duration) -> Self { + fn new(ttl: Duration, max_keys: usize) -> Self { Self { ttl, + max_keys: max_keys.max(1), keys: Mutex::new(HashMap::new()), } } @@ -156,21 +184,68 @@ impl IdempotencyStore { return false; } + if keys.len() >= self.max_keys { + let evict_key = keys + .iter() + .min_by_key(|(_, seen_at)| *seen_at) + .map(|(k, _)| k.clone()); + if let Some(evict_key) = evict_key { + keys.remove(&evict_key); + } + } + keys.insert(key.to_owned(), now); true } } -fn client_key_from_headers(headers: &HeaderMap) -> String { - for header_name in ["X-Forwarded-For", "X-Real-IP"] { - if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) { - let first = value.split(',').next().unwrap_or("").trim(); - if !first.is_empty() { - return first.to_owned(); +fn parse_client_ip(value: &str) -> Option { + let value = value.trim().trim_matches('"').trim(); + if value.is_empty() { + return None; + } + + if let Ok(ip) = value.parse::() { + return Some(ip); + } + + if let Ok(addr) = value.parse::() { + return Some(addr.ip()); + } + + let value = value.trim_matches(['[', ']']); + value.parse::().ok() +} + +fn forwarded_client_ip(headers: &HeaderMap) -> Option { + if let Some(xff) = headers.get("X-Forwarded-For").and_then(|v| v.to_str().ok()) { + for candidate in xff.split(',') { + if let Some(ip) = parse_client_ip(candidate) { + return Some(ip); } } } - "unknown".into() + + headers + .get("X-Real-IP") + .and_then(|v| v.to_str().ok()) + .and_then(parse_client_ip) +} + +fn client_key_from_request( + peer_addr: Option, + headers: &HeaderMap, + trust_forwarded_headers: bool, +) -> String { + if trust_forwarded_headers { + if let Some(ip) = forwarded_client_ip(headers) { + return ip.to_string(); + } + } + + peer_addr + .map(|addr| addr.ip().to_string()) + .unwrap_or_else(|| "unknown".to_string()) } /// Shared state for all axum handlers @@ -185,6 +260,7 @@ pub struct AppState { /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext. pub webhook_secret_hash: Option>, pub pairing: Arc, + pub trust_forwarded_headers: bool, pub rate_limiter: Arc, pub idempotency_store: Arc, pub whatsapp: Option>, @@ -305,10 +381,18 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let rate_limiter = Arc::new(GatewayRateLimiter::new( config.gateway.pair_rate_limit_per_minute, config.gateway.webhook_rate_limit_per_minute, + config + .gateway + .rate_limit_max_keys + .max(RATE_LIMIT_MAX_KEYS_DEFAULT), + )); + let idempotency_store = Arc::new(IdempotencyStore::new( + Duration::from_secs(config.gateway.idempotency_ttl_secs.max(1)), + config + .gateway + .idempotency_max_keys + .max(IDEMPOTENCY_MAX_KEYS_DEFAULT), )); - let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs( - config.gateway.idempotency_ttl_secs.max(1), - ))); // ── Tunnel ──────────────────────────────────────────────── let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?; @@ -365,6 +449,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { auto_save: config.memory.auto_save, webhook_secret_hash, pairing, + trust_forwarded_headers: config.gateway.trust_forwarded_headers, rate_limiter, idempotency_store, whatsapp: whatsapp_channel, @@ -386,7 +471,11 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { )); // Run the server - axum::serve(listener, app).await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await?; Ok(()) } @@ -406,8 +495,13 @@ async fn handle_health(State(state): State) -> impl IntoResponse { } /// POST /pair — exchange one-time code for bearer token -async fn handle_pair(State(state): State, headers: HeaderMap) -> impl IntoResponse { - let client_key = client_key_from_headers(&headers); +async fn handle_pair( + State(state): State, + ConnectInfo(peer_addr): ConnectInfo, + headers: HeaderMap, +) -> impl IntoResponse { + let client_key = + client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); if !state.rate_limiter.allow_pair(&client_key) { tracing::warn!("/pair rate limit exceeded for key: {client_key}"); let err = serde_json::json!({ @@ -479,10 +573,12 @@ pub struct WebhookBody { /// POST /webhook — main webhook endpoint async fn handle_webhook( State(state): State, + ConnectInfo(peer_addr): ConnectInfo, headers: HeaderMap, body: Result, axum::extract::rejection::JsonRejection>, ) -> impl IntoResponse { - let client_key = client_key_from_headers(&headers); + let client_key = + client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); if !state.rate_limiter.allow_webhook(&client_key) { tracing::warn!("/webhook rate limit exceeded for key: {client_key}"); let err = serde_json::json!({ @@ -803,7 +899,7 @@ mod tests { #[test] fn gateway_rate_limiter_blocks_after_limit() { - let limiter = GatewayRateLimiter::new(2, 2); + let limiter = GatewayRateLimiter::new(2, 2, 100); assert!(limiter.allow_pair("127.0.0.1")); assert!(limiter.allow_pair("127.0.0.1")); assert!(!limiter.allow_pair("127.0.0.1")); @@ -811,7 +907,7 @@ mod tests { #[test] fn rate_limiter_sweep_removes_stale_entries() { - let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60)); + let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60), 100); // Add entries for multiple IPs assert!(limiter.allow("ip-1")); assert!(limiter.allow("ip-2")); @@ -845,7 +941,7 @@ mod tests { #[test] fn rate_limiter_zero_limit_always_allows() { - let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60)); + let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60), 10); for _ in 0..100 { assert!(limiter.allow("any-key")); } @@ -853,12 +949,77 @@ mod tests { #[test] fn idempotency_store_rejects_duplicate_key() { - let store = IdempotencyStore::new(Duration::from_secs(30)); + let store = IdempotencyStore::new(Duration::from_secs(30), 10); assert!(store.record_if_new("req-1")); assert!(!store.record_if_new("req-1")); assert!(store.record_if_new("req-2")); } + #[test] + fn rate_limiter_bounded_cardinality_evicts_oldest_key() { + let limiter = SlidingWindowRateLimiter::new(5, Duration::from_secs(60), 2); + assert!(limiter.allow("ip-1")); + assert!(limiter.allow("ip-2")); + assert!(limiter.allow("ip-3")); + + let guard = limiter.requests.lock(); + assert_eq!(guard.0.len(), 2); + assert!(guard.0.contains_key("ip-2")); + assert!(guard.0.contains_key("ip-3")); + } + + #[test] + fn idempotency_store_bounded_cardinality_evicts_oldest_key() { + let store = IdempotencyStore::new(Duration::from_secs(300), 2); + assert!(store.record_if_new("k1")); + std::thread::sleep(Duration::from_millis(2)); + assert!(store.record_if_new("k2")); + std::thread::sleep(Duration::from_millis(2)); + assert!(store.record_if_new("k3")); + + let keys = store.keys.lock(); + assert_eq!(keys.len(), 2); + assert!(!keys.contains_key("k1")); + assert!(keys.contains_key("k2")); + assert!(keys.contains_key("k3")); + } + + #[test] + fn client_key_defaults_to_peer_addr_when_untrusted_proxy_mode() { + let peer = SocketAddr::from(([10, 0, 0, 5], 3000)); + let mut headers = HeaderMap::new(); + headers.insert( + "X-Forwarded-For", + HeaderValue::from_static("198.51.100.10, 203.0.113.11"), + ); + + let key = client_key_from_request(Some(peer), &headers, false); + assert_eq!(key, "10.0.0.5"); + } + + #[test] + fn client_key_uses_forwarded_ip_only_in_trusted_proxy_mode() { + let peer = SocketAddr::from(([10, 0, 0, 5], 3000)); + let mut headers = HeaderMap::new(); + headers.insert( + "X-Forwarded-For", + HeaderValue::from_static("198.51.100.10, 203.0.113.11"), + ); + + let key = client_key_from_request(Some(peer), &headers, true); + assert_eq!(key, "198.51.100.10"); + } + + #[test] + fn client_key_falls_back_to_peer_when_forwarded_header_invalid() { + let peer = SocketAddr::from(([10, 0, 0, 5], 3000)); + let mut headers = HeaderMap::new(); + headers.insert("X-Forwarded-For", HeaderValue::from_static("garbage-value")); + + let key = client_key_from_request(Some(peer), &headers, true); + assert_eq!(key, "10.0.0.5"); + } + #[test] fn persist_pairing_tokens_writes_config_tokens() { let temp = tempfile::tempdir().unwrap(); @@ -1040,6 +1201,10 @@ mod tests { } } + fn test_connect_info() -> ConnectInfo { + ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 30_300))) + } + #[tokio::test] async fn webhook_idempotency_skips_duplicate_provider_calls() { let provider_impl = Arc::new(MockProvider::default()); @@ -1055,8 +1220,9 @@ mod tests { auto_save: false, webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, }; @@ -1067,15 +1233,20 @@ mod tests { let body = Ok(Json(WebhookBody { message: "hello".into(), })); - let first = handle_webhook(State(state.clone()), headers.clone(), body) - .await - .into_response(); + let first = handle_webhook( + State(state.clone()), + test_connect_info(), + headers.clone(), + body, + ) + .await + .into_response(); assert_eq!(first.status(), StatusCode::OK); let body = Ok(Json(WebhookBody { message: "hello".into(), })); - let second = handle_webhook(State(state), headers, body) + let second = handle_webhook(State(state), test_connect_info(), headers, body) .await .into_response(); assert_eq!(second.status(), StatusCode::OK); @@ -1104,8 +1275,9 @@ mod tests { auto_save: true, webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, }; @@ -1115,15 +1287,20 @@ mod tests { let body1 = Ok(Json(WebhookBody { message: "hello one".into(), })); - let first = handle_webhook(State(state.clone()), headers.clone(), body1) - .await - .into_response(); + let first = handle_webhook( + State(state.clone()), + test_connect_info(), + headers.clone(), + body1, + ) + .await + .into_response(); assert_eq!(first.status(), StatusCode::OK); let body2 = Ok(Json(WebhookBody { message: "hello two".into(), })); - let second = handle_webhook(State(state), headers, body2) + let second = handle_webhook(State(state), test_connect_info(), headers, body2) .await .into_response(); assert_eq!(second.status(), StatusCode::OK); @@ -1162,14 +1339,16 @@ mod tests { auto_save: false, webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, }; let response = handle_webhook( State(state), + test_connect_info(), HeaderMap::new(), Ok(Json(WebhookBody { message: "hello".into(), @@ -1197,8 +1376,9 @@ mod tests { auto_save: false, webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, }; @@ -1208,6 +1388,7 @@ mod tests { let response = handle_webhook( State(state), + test_connect_info(), headers, Ok(Json(WebhookBody { message: "hello".into(), @@ -1235,8 +1416,9 @@ mod tests { auto_save: false, webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + trust_forwarded_headers: false, + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, }; @@ -1246,6 +1428,7 @@ mod tests { let response = handle_webhook( State(state), + test_connect_info(), headers, Ok(Json(WebhookBody { message: "hello".into(),