fix(gateway): harden client identity and bound key stores
This commit is contained in:
parent
b1c04d8f88
commit
c507856710
2 changed files with 267 additions and 48 deletions
|
|
@ -479,9 +479,22 @@ pub struct GatewayConfig {
|
||||||
#[serde(default = "default_webhook_rate_limit")]
|
#[serde(default = "default_webhook_rate_limit")]
|
||||||
pub webhook_rate_limit_per_minute: u32,
|
pub webhook_rate_limit_per_minute: u32,
|
||||||
|
|
||||||
|
/// Trust proxy-forwarded client IP headers (`X-Forwarded-For`, `X-Real-IP`).
|
||||||
|
/// Disabled by default; enable only behind a trusted reverse proxy.
|
||||||
|
#[serde(default)]
|
||||||
|
pub trust_forwarded_headers: bool,
|
||||||
|
|
||||||
|
/// Maximum distinct client keys tracked by gateway rate limiter maps.
|
||||||
|
#[serde(default = "default_gateway_rate_limit_max_keys")]
|
||||||
|
pub rate_limit_max_keys: usize,
|
||||||
|
|
||||||
/// TTL for webhook idempotency keys.
|
/// TTL for webhook idempotency keys.
|
||||||
#[serde(default = "default_idempotency_ttl_secs")]
|
#[serde(default = "default_idempotency_ttl_secs")]
|
||||||
pub idempotency_ttl_secs: u64,
|
pub idempotency_ttl_secs: u64,
|
||||||
|
|
||||||
|
/// Maximum distinct idempotency keys retained in memory.
|
||||||
|
#[serde(default = "default_gateway_idempotency_max_keys")]
|
||||||
|
pub idempotency_max_keys: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_gateway_port() -> u16 {
|
fn default_gateway_port() -> u16 {
|
||||||
|
|
@ -504,6 +517,14 @@ fn default_idempotency_ttl_secs() -> u64 {
|
||||||
300
|
300
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_gateway_rate_limit_max_keys() -> usize {
|
||||||
|
10_000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_gateway_idempotency_max_keys() -> usize {
|
||||||
|
10_000
|
||||||
|
}
|
||||||
|
|
||||||
fn default_true() -> bool {
|
fn default_true() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
@ -518,7 +539,10 @@ impl Default for GatewayConfig {
|
||||||
paired_tokens: Vec::new(),
|
paired_tokens: Vec::new(),
|
||||||
pair_rate_limit_per_minute: default_pair_rate_limit(),
|
pair_rate_limit_per_minute: default_pair_rate_limit(),
|
||||||
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
|
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
|
||||||
|
trust_forwarded_headers: false,
|
||||||
|
rate_limit_max_keys: default_gateway_rate_limit_max_keys(),
|
||||||
idempotency_ttl_secs: default_idempotency_ttl_secs(),
|
idempotency_ttl_secs: default_idempotency_ttl_secs(),
|
||||||
|
idempotency_max_keys: default_gateway_idempotency_max_keys(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -2946,7 +2970,10 @@ channel_id = "C123"
|
||||||
);
|
);
|
||||||
assert_eq!(g.pair_rate_limit_per_minute, 10);
|
assert_eq!(g.pair_rate_limit_per_minute, 10);
|
||||||
assert_eq!(g.webhook_rate_limit_per_minute, 60);
|
assert_eq!(g.webhook_rate_limit_per_minute, 60);
|
||||||
|
assert!(!g.trust_forwarded_headers);
|
||||||
|
assert_eq!(g.rate_limit_max_keys, 10_000);
|
||||||
assert_eq!(g.idempotency_ttl_secs, 300);
|
assert_eq!(g.idempotency_ttl_secs, 300);
|
||||||
|
assert_eq!(g.idempotency_max_keys, 10_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -2974,7 +3001,10 @@ channel_id = "C123"
|
||||||
paired_tokens: vec!["zc_test_token".into()],
|
paired_tokens: vec!["zc_test_token".into()],
|
||||||
pair_rate_limit_per_minute: 12,
|
pair_rate_limit_per_minute: 12,
|
||||||
webhook_rate_limit_per_minute: 80,
|
webhook_rate_limit_per_minute: 80,
|
||||||
|
trust_forwarded_headers: true,
|
||||||
|
rate_limit_max_keys: 2048,
|
||||||
idempotency_ttl_secs: 600,
|
idempotency_ttl_secs: 600,
|
||||||
|
idempotency_max_keys: 4096,
|
||||||
};
|
};
|
||||||
let toml_str = toml::to_string(&g).unwrap();
|
let toml_str = toml::to_string(&g).unwrap();
|
||||||
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
|
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
|
@ -2983,7 +3013,10 @@ channel_id = "C123"
|
||||||
assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]);
|
assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]);
|
||||||
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
|
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
|
||||||
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
|
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
|
||||||
|
assert!(parsed.trust_forwarded_headers);
|
||||||
|
assert_eq!(parsed.rate_limit_max_keys, 2048);
|
||||||
assert_eq!(parsed.idempotency_ttl_secs, 600);
|
assert_eq!(parsed.idempotency_ttl_secs, 600);
|
||||||
|
assert_eq!(parsed.idempotency_max_keys, 4096);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -3622,6 +3655,9 @@ default_model = "legacy-model"
|
||||||
assert!(g.require_pairing);
|
assert!(g.require_pairing);
|
||||||
assert!(!g.allow_public_bind);
|
assert!(!g.allow_public_bind);
|
||||||
assert!(g.paired_tokens.is_empty());
|
assert!(g.paired_tokens.is_empty());
|
||||||
|
assert!(!g.trust_forwarded_headers);
|
||||||
|
assert_eq!(g.rate_limit_max_keys, 10_000);
|
||||||
|
assert_eq!(g.idempotency_max_keys, 10_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Peripherals config ───────────────────────────────────────
|
// ── Peripherals config ───────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Bytes,
|
body::Bytes,
|
||||||
extract::{Query, State},
|
extract::{ConnectInfo, Query, State},
|
||||||
http::{header, HeaderMap, StatusCode},
|
http::{header, HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Json},
|
response::{IntoResponse, Json},
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
|
|
@ -27,7 +27,7 @@ use axum::{
|
||||||
};
|
};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
|
|
@ -40,6 +40,10 @@ pub const MAX_BODY_SIZE: usize = 65_536;
|
||||||
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||||
/// Sliding window used by gateway rate limiting.
|
/// Sliding window used by gateway rate limiting.
|
||||||
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||||
|
/// Fallback max distinct client keys tracked in gateway rate limiter.
|
||||||
|
pub const RATE_LIMIT_MAX_KEYS_DEFAULT: usize = 10_000;
|
||||||
|
/// Fallback max distinct idempotency keys retained in gateway memory.
|
||||||
|
pub const IDEMPOTENCY_MAX_KEYS_DEFAULT: usize = 10_000;
|
||||||
|
|
||||||
fn webhook_memory_key() -> String {
|
fn webhook_memory_key() -> String {
|
||||||
format!("webhook_msg_{}", Uuid::new_v4())
|
format!("webhook_msg_{}", Uuid::new_v4())
|
||||||
|
|
@ -63,18 +67,27 @@ const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||||
struct SlidingWindowRateLimiter {
|
struct SlidingWindowRateLimiter {
|
||||||
limit_per_window: u32,
|
limit_per_window: u32,
|
||||||
window: Duration,
|
window: Duration,
|
||||||
|
max_keys: usize,
|
||||||
requests: Mutex<(HashMap<String, Vec<Instant>>, Instant)>,
|
requests: Mutex<(HashMap<String, Vec<Instant>>, Instant)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SlidingWindowRateLimiter {
|
impl SlidingWindowRateLimiter {
|
||||||
fn new(limit_per_window: u32, window: Duration) -> Self {
|
fn new(limit_per_window: u32, window: Duration, max_keys: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
limit_per_window,
|
limit_per_window,
|
||||||
window,
|
window,
|
||||||
|
max_keys: max_keys.max(1),
|
||||||
requests: Mutex::new((HashMap::new(), Instant::now())),
|
requests: Mutex::new((HashMap::new(), Instant::now())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn prune_stale(requests: &mut HashMap<String, Vec<Instant>>, cutoff: Instant) {
|
||||||
|
requests.retain(|_, timestamps| {
|
||||||
|
timestamps.retain(|t| *t > cutoff);
|
||||||
|
!timestamps.is_empty()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn allow(&self, key: &str) -> bool {
|
fn allow(&self, key: &str) -> bool {
|
||||||
if self.limit_per_window == 0 {
|
if self.limit_per_window == 0 {
|
||||||
return true;
|
return true;
|
||||||
|
|
@ -86,15 +99,28 @@ impl SlidingWindowRateLimiter {
|
||||||
let mut guard = self.requests.lock();
|
let mut guard = self.requests.lock();
|
||||||
let (requests, last_sweep) = &mut *guard;
|
let (requests, last_sweep) = &mut *guard;
|
||||||
|
|
||||||
// Periodic sweep: remove IPs with no recent requests
|
// Periodic sweep: remove keys with no recent requests
|
||||||
if last_sweep.elapsed() >= Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS) {
|
if last_sweep.elapsed() >= Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS) {
|
||||||
requests.retain(|_, timestamps| {
|
Self::prune_stale(requests, cutoff);
|
||||||
timestamps.retain(|t| *t > cutoff);
|
|
||||||
!timestamps.is_empty()
|
|
||||||
});
|
|
||||||
*last_sweep = now;
|
*last_sweep = now;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !requests.contains_key(key) && requests.len() >= self.max_keys {
|
||||||
|
// Opportunistic stale cleanup before eviction under cardinality pressure.
|
||||||
|
Self::prune_stale(requests, cutoff);
|
||||||
|
*last_sweep = now;
|
||||||
|
|
||||||
|
if requests.len() >= self.max_keys {
|
||||||
|
let evict_key = requests
|
||||||
|
.iter()
|
||||||
|
.min_by_key(|(_, timestamps)| timestamps.last().copied().unwrap_or(cutoff))
|
||||||
|
.map(|(k, _)| k.clone());
|
||||||
|
if let Some(evict_key) = evict_key {
|
||||||
|
requests.remove(&evict_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let entry = requests.entry(key.to_owned()).or_default();
|
let entry = requests.entry(key.to_owned()).or_default();
|
||||||
entry.retain(|instant| *instant > cutoff);
|
entry.retain(|instant| *instant > cutoff);
|
||||||
|
|
||||||
|
|
@ -114,11 +140,11 @@ pub struct GatewayRateLimiter {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GatewayRateLimiter {
|
impl GatewayRateLimiter {
|
||||||
fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
|
fn new(pair_per_minute: u32, webhook_per_minute: u32, max_keys: usize) -> Self {
|
||||||
let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
||||||
Self {
|
Self {
|
||||||
pair: SlidingWindowRateLimiter::new(pair_per_minute, window),
|
pair: SlidingWindowRateLimiter::new(pair_per_minute, window, max_keys),
|
||||||
webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window),
|
webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window, max_keys),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -134,13 +160,15 @@ impl GatewayRateLimiter {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct IdempotencyStore {
|
pub struct IdempotencyStore {
|
||||||
ttl: Duration,
|
ttl: Duration,
|
||||||
|
max_keys: usize,
|
||||||
keys: Mutex<HashMap<String, Instant>>,
|
keys: Mutex<HashMap<String, Instant>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IdempotencyStore {
|
impl IdempotencyStore {
|
||||||
fn new(ttl: Duration) -> Self {
|
fn new(ttl: Duration, max_keys: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ttl,
|
ttl,
|
||||||
|
max_keys: max_keys.max(1),
|
||||||
keys: Mutex::new(HashMap::new()),
|
keys: Mutex::new(HashMap::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -156,21 +184,68 @@ impl IdempotencyStore {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if keys.len() >= self.max_keys {
|
||||||
|
let evict_key = keys
|
||||||
|
.iter()
|
||||||
|
.min_by_key(|(_, seen_at)| *seen_at)
|
||||||
|
.map(|(k, _)| k.clone());
|
||||||
|
if let Some(evict_key) = evict_key {
|
||||||
|
keys.remove(&evict_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
keys.insert(key.to_owned(), now);
|
keys.insert(key.to_owned(), now);
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn client_key_from_headers(headers: &HeaderMap) -> String {
|
fn parse_client_ip(value: &str) -> Option<IpAddr> {
|
||||||
for header_name in ["X-Forwarded-For", "X-Real-IP"] {
|
let value = value.trim().trim_matches('"').trim();
|
||||||
if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) {
|
if value.is_empty() {
|
||||||
let first = value.split(',').next().unwrap_or("").trim();
|
return None;
|
||||||
if !first.is_empty() {
|
}
|
||||||
return first.to_owned();
|
|
||||||
|
if let Ok(ip) = value.parse::<IpAddr>() {
|
||||||
|
return Some(ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(addr) = value.parse::<SocketAddr>() {
|
||||||
|
return Some(addr.ip());
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = value.trim_matches(['[', ']']);
|
||||||
|
value.parse::<IpAddr>().ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forwarded_client_ip(headers: &HeaderMap) -> Option<IpAddr> {
|
||||||
|
if let Some(xff) = headers.get("X-Forwarded-For").and_then(|v| v.to_str().ok()) {
|
||||||
|
for candidate in xff.split(',') {
|
||||||
|
if let Some(ip) = parse_client_ip(candidate) {
|
||||||
|
return Some(ip);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"unknown".into()
|
|
||||||
|
headers
|
||||||
|
.get("X-Real-IP")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(parse_client_ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_key_from_request(
|
||||||
|
peer_addr: Option<SocketAddr>,
|
||||||
|
headers: &HeaderMap,
|
||||||
|
trust_forwarded_headers: bool,
|
||||||
|
) -> String {
|
||||||
|
if trust_forwarded_headers {
|
||||||
|
if let Some(ip) = forwarded_client_ip(headers) {
|
||||||
|
return ip.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peer_addr
|
||||||
|
.map(|addr| addr.ip().to_string())
|
||||||
|
.unwrap_or_else(|| "unknown".to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Shared state for all axum handlers
|
/// Shared state for all axum handlers
|
||||||
|
|
@ -185,6 +260,7 @@ pub struct AppState {
|
||||||
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
||||||
pub webhook_secret_hash: Option<Arc<str>>,
|
pub webhook_secret_hash: Option<Arc<str>>,
|
||||||
pub pairing: Arc<PairingGuard>,
|
pub pairing: Arc<PairingGuard>,
|
||||||
|
pub trust_forwarded_headers: bool,
|
||||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||||
pub idempotency_store: Arc<IdempotencyStore>,
|
pub idempotency_store: Arc<IdempotencyStore>,
|
||||||
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
||||||
|
|
@ -305,10 +381,18 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let rate_limiter = Arc::new(GatewayRateLimiter::new(
|
let rate_limiter = Arc::new(GatewayRateLimiter::new(
|
||||||
config.gateway.pair_rate_limit_per_minute,
|
config.gateway.pair_rate_limit_per_minute,
|
||||||
config.gateway.webhook_rate_limit_per_minute,
|
config.gateway.webhook_rate_limit_per_minute,
|
||||||
|
config
|
||||||
|
.gateway
|
||||||
|
.rate_limit_max_keys
|
||||||
|
.max(RATE_LIMIT_MAX_KEYS_DEFAULT),
|
||||||
|
));
|
||||||
|
let idempotency_store = Arc::new(IdempotencyStore::new(
|
||||||
|
Duration::from_secs(config.gateway.idempotency_ttl_secs.max(1)),
|
||||||
|
config
|
||||||
|
.gateway
|
||||||
|
.idempotency_max_keys
|
||||||
|
.max(IDEMPOTENCY_MAX_KEYS_DEFAULT),
|
||||||
));
|
));
|
||||||
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
|
|
||||||
config.gateway.idempotency_ttl_secs.max(1),
|
|
||||||
)));
|
|
||||||
|
|
||||||
// ── Tunnel ────────────────────────────────────────────────
|
// ── Tunnel ────────────────────────────────────────────────
|
||||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||||
|
|
@ -365,6 +449,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
auto_save: config.memory.auto_save,
|
auto_save: config.memory.auto_save,
|
||||||
webhook_secret_hash,
|
webhook_secret_hash,
|
||||||
pairing,
|
pairing,
|
||||||
|
trust_forwarded_headers: config.gateway.trust_forwarded_headers,
|
||||||
rate_limiter,
|
rate_limiter,
|
||||||
idempotency_store,
|
idempotency_store,
|
||||||
whatsapp: whatsapp_channel,
|
whatsapp: whatsapp_channel,
|
||||||
|
|
@ -386,7 +471,11 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
));
|
));
|
||||||
|
|
||||||
// Run the server
|
// Run the server
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(
|
||||||
|
listener,
|
||||||
|
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -406,8 +495,13 @@ async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /pair — exchange one-time code for bearer token
|
/// POST /pair — exchange one-time code for bearer token
|
||||||
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
async fn handle_pair(
|
||||||
let client_key = client_key_from_headers(&headers);
|
State(state): State<AppState>,
|
||||||
|
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let client_key =
|
||||||
|
client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
|
||||||
if !state.rate_limiter.allow_pair(&client_key) {
|
if !state.rate_limiter.allow_pair(&client_key) {
|
||||||
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
|
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
|
||||||
let err = serde_json::json!({
|
let err = serde_json::json!({
|
||||||
|
|
@ -479,10 +573,12 @@ pub struct WebhookBody {
|
||||||
/// POST /webhook — main webhook endpoint
|
/// POST /webhook — main webhook endpoint
|
||||||
async fn handle_webhook(
|
async fn handle_webhook(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
|
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let client_key = client_key_from_headers(&headers);
|
let client_key =
|
||||||
|
client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers);
|
||||||
if !state.rate_limiter.allow_webhook(&client_key) {
|
if !state.rate_limiter.allow_webhook(&client_key) {
|
||||||
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
|
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
|
||||||
let err = serde_json::json!({
|
let err = serde_json::json!({
|
||||||
|
|
@ -803,7 +899,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gateway_rate_limiter_blocks_after_limit() {
|
fn gateway_rate_limiter_blocks_after_limit() {
|
||||||
let limiter = GatewayRateLimiter::new(2, 2);
|
let limiter = GatewayRateLimiter::new(2, 2, 100);
|
||||||
assert!(limiter.allow_pair("127.0.0.1"));
|
assert!(limiter.allow_pair("127.0.0.1"));
|
||||||
assert!(limiter.allow_pair("127.0.0.1"));
|
assert!(limiter.allow_pair("127.0.0.1"));
|
||||||
assert!(!limiter.allow_pair("127.0.0.1"));
|
assert!(!limiter.allow_pair("127.0.0.1"));
|
||||||
|
|
@ -811,7 +907,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rate_limiter_sweep_removes_stale_entries() {
|
fn rate_limiter_sweep_removes_stale_entries() {
|
||||||
let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60));
|
let limiter = SlidingWindowRateLimiter::new(10, Duration::from_secs(60), 100);
|
||||||
// Add entries for multiple IPs
|
// Add entries for multiple IPs
|
||||||
assert!(limiter.allow("ip-1"));
|
assert!(limiter.allow("ip-1"));
|
||||||
assert!(limiter.allow("ip-2"));
|
assert!(limiter.allow("ip-2"));
|
||||||
|
|
@ -845,7 +941,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rate_limiter_zero_limit_always_allows() {
|
fn rate_limiter_zero_limit_always_allows() {
|
||||||
let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60));
|
let limiter = SlidingWindowRateLimiter::new(0, Duration::from_secs(60), 10);
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
assert!(limiter.allow("any-key"));
|
assert!(limiter.allow("any-key"));
|
||||||
}
|
}
|
||||||
|
|
@ -853,12 +949,77 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn idempotency_store_rejects_duplicate_key() {
|
fn idempotency_store_rejects_duplicate_key() {
|
||||||
let store = IdempotencyStore::new(Duration::from_secs(30));
|
let store = IdempotencyStore::new(Duration::from_secs(30), 10);
|
||||||
assert!(store.record_if_new("req-1"));
|
assert!(store.record_if_new("req-1"));
|
||||||
assert!(!store.record_if_new("req-1"));
|
assert!(!store.record_if_new("req-1"));
|
||||||
assert!(store.record_if_new("req-2"));
|
assert!(store.record_if_new("req-2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rate_limiter_bounded_cardinality_evicts_oldest_key() {
|
||||||
|
let limiter = SlidingWindowRateLimiter::new(5, Duration::from_secs(60), 2);
|
||||||
|
assert!(limiter.allow("ip-1"));
|
||||||
|
assert!(limiter.allow("ip-2"));
|
||||||
|
assert!(limiter.allow("ip-3"));
|
||||||
|
|
||||||
|
let guard = limiter.requests.lock();
|
||||||
|
assert_eq!(guard.0.len(), 2);
|
||||||
|
assert!(guard.0.contains_key("ip-2"));
|
||||||
|
assert!(guard.0.contains_key("ip-3"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn idempotency_store_bounded_cardinality_evicts_oldest_key() {
|
||||||
|
let store = IdempotencyStore::new(Duration::from_secs(300), 2);
|
||||||
|
assert!(store.record_if_new("k1"));
|
||||||
|
std::thread::sleep(Duration::from_millis(2));
|
||||||
|
assert!(store.record_if_new("k2"));
|
||||||
|
std::thread::sleep(Duration::from_millis(2));
|
||||||
|
assert!(store.record_if_new("k3"));
|
||||||
|
|
||||||
|
let keys = store.keys.lock();
|
||||||
|
assert_eq!(keys.len(), 2);
|
||||||
|
assert!(!keys.contains_key("k1"));
|
||||||
|
assert!(keys.contains_key("k2"));
|
||||||
|
assert!(keys.contains_key("k3"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn client_key_defaults_to_peer_addr_when_untrusted_proxy_mode() {
|
||||||
|
let peer = SocketAddr::from(([10, 0, 0, 5], 3000));
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"X-Forwarded-For",
|
||||||
|
HeaderValue::from_static("198.51.100.10, 203.0.113.11"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let key = client_key_from_request(Some(peer), &headers, false);
|
||||||
|
assert_eq!(key, "10.0.0.5");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn client_key_uses_forwarded_ip_only_in_trusted_proxy_mode() {
|
||||||
|
let peer = SocketAddr::from(([10, 0, 0, 5], 3000));
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
"X-Forwarded-For",
|
||||||
|
HeaderValue::from_static("198.51.100.10, 203.0.113.11"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let key = client_key_from_request(Some(peer), &headers, true);
|
||||||
|
assert_eq!(key, "198.51.100.10");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn client_key_falls_back_to_peer_when_forwarded_header_invalid() {
|
||||||
|
let peer = SocketAddr::from(([10, 0, 0, 5], 3000));
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Forwarded-For", HeaderValue::from_static("garbage-value"));
|
||||||
|
|
||||||
|
let key = client_key_from_request(Some(peer), &headers, true);
|
||||||
|
assert_eq!(key, "10.0.0.5");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn persist_pairing_tokens_writes_config_tokens() {
|
fn persist_pairing_tokens_writes_config_tokens() {
|
||||||
let temp = tempfile::tempdir().unwrap();
|
let temp = tempfile::tempdir().unwrap();
|
||||||
|
|
@ -1040,6 +1201,10 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_connect_info() -> ConnectInfo<SocketAddr> {
|
||||||
|
ConnectInfo(SocketAddr::from(([127, 0, 0, 1], 30_300)))
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
||||||
let provider_impl = Arc::new(MockProvider::default());
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
|
@ -1055,8 +1220,9 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
};
|
||||||
|
|
@ -1067,15 +1233,20 @@ mod tests {
|
||||||
let body = Ok(Json(WebhookBody {
|
let body = Ok(Json(WebhookBody {
|
||||||
message: "hello".into(),
|
message: "hello".into(),
|
||||||
}));
|
}));
|
||||||
let first = handle_webhook(State(state.clone()), headers.clone(), body)
|
let first = handle_webhook(
|
||||||
.await
|
State(state.clone()),
|
||||||
.into_response();
|
test_connect_info(),
|
||||||
|
headers.clone(),
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
assert_eq!(first.status(), StatusCode::OK);
|
assert_eq!(first.status(), StatusCode::OK);
|
||||||
|
|
||||||
let body = Ok(Json(WebhookBody {
|
let body = Ok(Json(WebhookBody {
|
||||||
message: "hello".into(),
|
message: "hello".into(),
|
||||||
}));
|
}));
|
||||||
let second = handle_webhook(State(state), headers, body)
|
let second = handle_webhook(State(state), test_connect_info(), headers, body)
|
||||||
.await
|
.await
|
||||||
.into_response();
|
.into_response();
|
||||||
assert_eq!(second.status(), StatusCode::OK);
|
assert_eq!(second.status(), StatusCode::OK);
|
||||||
|
|
@ -1104,8 +1275,9 @@ mod tests {
|
||||||
auto_save: true,
|
auto_save: true,
|
||||||
webhook_secret_hash: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
};
|
||||||
|
|
@ -1115,15 +1287,20 @@ mod tests {
|
||||||
let body1 = Ok(Json(WebhookBody {
|
let body1 = Ok(Json(WebhookBody {
|
||||||
message: "hello one".into(),
|
message: "hello one".into(),
|
||||||
}));
|
}));
|
||||||
let first = handle_webhook(State(state.clone()), headers.clone(), body1)
|
let first = handle_webhook(
|
||||||
.await
|
State(state.clone()),
|
||||||
.into_response();
|
test_connect_info(),
|
||||||
|
headers.clone(),
|
||||||
|
body1,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
assert_eq!(first.status(), StatusCode::OK);
|
assert_eq!(first.status(), StatusCode::OK);
|
||||||
|
|
||||||
let body2 = Ok(Json(WebhookBody {
|
let body2 = Ok(Json(WebhookBody {
|
||||||
message: "hello two".into(),
|
message: "hello two".into(),
|
||||||
}));
|
}));
|
||||||
let second = handle_webhook(State(state), headers, body2)
|
let second = handle_webhook(State(state), test_connect_info(), headers, body2)
|
||||||
.await
|
.await
|
||||||
.into_response();
|
.into_response();
|
||||||
assert_eq!(second.status(), StatusCode::OK);
|
assert_eq!(second.status(), StatusCode::OK);
|
||||||
|
|
@ -1162,14 +1339,16 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = handle_webhook(
|
let response = handle_webhook(
|
||||||
State(state),
|
State(state),
|
||||||
|
test_connect_info(),
|
||||||
HeaderMap::new(),
|
HeaderMap::new(),
|
||||||
Ok(Json(WebhookBody {
|
Ok(Json(WebhookBody {
|
||||||
message: "hello".into(),
|
message: "hello".into(),
|
||||||
|
|
@ -1197,8 +1376,9 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
};
|
||||||
|
|
@ -1208,6 +1388,7 @@ mod tests {
|
||||||
|
|
||||||
let response = handle_webhook(
|
let response = handle_webhook(
|
||||||
State(state),
|
State(state),
|
||||||
|
test_connect_info(),
|
||||||
headers,
|
headers,
|
||||||
Ok(Json(WebhookBody {
|
Ok(Json(WebhookBody {
|
||||||
message: "hello".into(),
|
message: "hello".into(),
|
||||||
|
|
@ -1235,8 +1416,9 @@ mod tests {
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
trust_forwarded_headers: false,
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
};
|
||||||
|
|
@ -1246,6 +1428,7 @@ mod tests {
|
||||||
|
|
||||||
let response = handle_webhook(
|
let response = handle_webhook(
|
||||||
State(state),
|
State(state),
|
||||||
|
test_connect_info(),
|
||||||
headers,
|
headers,
|
||||||
Ok(Json(WebhookBody {
|
Ok(Json(WebhookBody {
|
||||||
message: "hello".into(),
|
message: "hello".into(),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue