zeroclaw/src/gateway/mod.rs
Vernon Stinebaker 40c41cf3d2
feat(discord): add listen_to_bots config and fix model IDs across codebase (#280)
* fix(config): apply env overrides at runtime and fix Docker compose defaults

- Call apply_env_overrides() after Config::load_or_init() in main.rs so
  environment variables (API_KEY, PROVIDER, ZEROCLAW_GATEWAY_PORT, etc.)
  are actually applied at runtime, not just in tests
- Add ZEROCLAW_ALLOW_PUBLIC_BIND env var support for gateway bind policy
- Fix docker-compose.yml: correct volume path (/zeroclaw-data not /data),
  add ZEROCLAW_ALLOW_PUBLIC_BIND=true for container networking, make host
  port configurable via HOST_PORT env var
- Add docker-compose.override.yml to .gitignore for local dev overrides

* feat(discord): add listen_to_bots config and fix model IDs across codebase

Add listen_to_bots field to DiscordConfig so bot messages are processed
when explicitly enabled (defaults to false for backward compat). Remove
ZEROCLAW_MODEL from Dockerfile release stage so config.toml is the
source of truth for model selection. Fix all hardcoded model IDs from
the dated anthropic/claude-sonnet-4-20250514 to the valid OpenRouter
identifier anthropic/claude-sonnet-4.
2026-02-16 02:13:36 -05:00

1176 lines
37 KiB
Rust

//! Axum-based HTTP gateway with proper HTTP/1.1 compliance, body limits, and timeouts.
//!
//! This module replaces the raw TCP implementation with axum for:
//! - Proper HTTP/1.1 parsing and compliance
//! - Content-Length validation (handled by hyper)
//! - Request body size limits (64KB max)
//! - Request timeouts (30s) to prevent slow-loris attacks
//! - Header sanitization (handled by axum/hyper)
use crate::channels::{Channel, WhatsAppChannel};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::providers::{self, Provider};
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use axum::{
body::Bytes,
extract::{Query, State},
http::{header, HeaderMap, StatusCode},
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use uuid::Uuid;
/// Largest request body the gateway accepts (64 KiB).
///
/// Enforced by `RequestBodyLimitLayer` on the router so oversized payloads are
/// rejected before any handler buffers them.
pub const MAX_BODY_SIZE: usize = 64 * 1024;
/// Per-request deadline in seconds, enforced by `TimeoutLayer` (answers 408).
///
/// Cuts off clients that trickle requests in slowly (slow-loris).
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Length, in seconds, of the sliding window used by the gateway rate limiters.
///
/// Also reported to throttled clients as `retry_after`.
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Build a fresh memory key for an auto-saved `/webhook` message.
///
/// A random v4 UUID suffix gives every request its own key, so repeated
/// messages never overwrite each other.
fn webhook_memory_key() -> String {
    let id = Uuid::new_v4();
    format!("webhook_msg_{id}")
}
/// Build the memory key for an auto-saved WhatsApp message.
///
/// The key has the form `whatsapp_<sender>_<message id>`.
fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String {
    let (sender, id) = (&msg.sender, &msg.id);
    format!("whatsapp_{sender}_{id}")
}
/// Mutable bookkeeping for [`SlidingWindowRateLimiter`], kept under one lock so
/// sweeping and recording can never interleave.
#[derive(Debug)]
struct RateLimiterState {
    /// Timestamps of accepted requests per client key (only live ones matter).
    requests: HashMap<String, Vec<Instant>>,
    /// When keys without live timestamps were last purged from `requests`.
    last_sweep: Instant,
}

/// Sliding-window request limiter keyed by an arbitrary client identifier.
///
/// Each key may make at most `limit_per_window` requests within any rolling
/// `window`. A limit of `0` disables limiting entirely.
#[derive(Debug)]
struct SlidingWindowRateLimiter {
    limit_per_window: u32,
    window: Duration,
    state: Mutex<RateLimiterState>,
}

impl SlidingWindowRateLimiter {
    fn new(limit_per_window: u32, window: Duration) -> Self {
        Self {
            limit_per_window,
            window,
            state: Mutex::new(RateLimiterState {
                requests: HashMap::new(),
                last_sweep: Instant::now(),
            }),
        }
    }

    /// Record a request for `key` and report whether it is within the limit.
    ///
    /// Rejected requests are not recorded, so they do not extend the
    /// client's penalty.
    fn allow(&self, key: &str) -> bool {
        if self.limit_per_window == 0 {
            return true;
        }
        let now = Instant::now();
        let window = self.window;
        // A timestamp still counts while it is younger than `window`.
        // `duration_since` saturates at zero, so no `checked_sub` is needed.
        // The old `checked_sub` fallback discarded all history (disabling the
        // limiter) whenever the clock's origin was less than `window` ago.
        let is_live = |seen_at: &Instant| now.duration_since(*seen_at) < window;

        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Keys come from request headers that clients control, so
        // periodically evict keys with no live timestamps. Otherwise an
        // attacker rotating keys could grow the map without bound.
        if now.duration_since(state.last_sweep) >= window {
            state.requests.retain(|_, stamps| {
                stamps.retain(is_live);
                !stamps.is_empty()
            });
            state.last_sweep = now;
        }

        let stamps = state.requests.entry(key.to_owned()).or_default();
        stamps.retain(is_live);
        if stamps.len() >= self.limit_per_window as usize {
            return false;
        }
        stamps.push(now);
        true
    }
}
/// Per-client throttling for the gateway's `/pair` and `/webhook` endpoints.
///
/// Each endpoint gets its own independent sliding-window budget. Both use the
/// same window length, [`RATE_LIMIT_WINDOW_SECS`].
#[derive(Debug)]
pub struct GatewayRateLimiter {
    pair: SlidingWindowRateLimiter,
    webhook: SlidingWindowRateLimiter,
}

impl GatewayRateLimiter {
    fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
        let limiter_for = |limit| {
            SlidingWindowRateLimiter::new(limit, Duration::from_secs(RATE_LIMIT_WINDOW_SECS))
        };
        Self {
            pair: limiter_for(pair_per_minute),
            webhook: limiter_for(webhook_per_minute),
        }
    }

    /// Check (and record) a `/pair` attempt for `key`.
    fn allow_pair(&self, key: &str) -> bool {
        self.pair.allow(key)
    }

    /// Check (and record) a `/webhook` request for `key`.
    fn allow_webhook(&self, key: &str) -> bool {
        self.webhook.allow(key)
    }
}
/// Remembers recently seen idempotency keys for `ttl`, so duplicate webhook
/// deliveries carrying the same key are processed only once.
#[derive(Debug)]
pub struct IdempotencyStore {
    ttl: Duration,
    keys: Mutex<HashMap<String, Instant>>,
}

impl IdempotencyStore {
    fn new(ttl: Duration) -> Self {
        let keys = Mutex::new(HashMap::new());
        Self { ttl, keys }
    }

    /// Record `key` if it has not been seen within the TTL.
    ///
    /// Returns `true` when the key is new (and now recorded), `false` for a
    /// duplicate. Expired entries are pruned on every call.
    fn record_if_new(&self, key: &str) -> bool {
        let now = Instant::now();
        let ttl = self.ttl;
        let mut seen = match self.keys.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        seen.retain(|_, recorded_at| now.duration_since(*recorded_at) < ttl);

        let is_duplicate = seen.contains_key(key);
        if !is_duplicate {
            seen.insert(key.to_owned(), now);
        }
        !is_duplicate
    }
}
/// Derive the rate-limiting key for a request from proxy headers.
///
/// Looks at `X-Forwarded-For` first, then `X-Real-IP`. From the first header
/// that converts to `&str` and whose leading comma-separated entry is
/// non-empty after trimming, that entry is returned. Otherwise the result is
/// `"unknown"`.
///
/// NOTE(review): both headers are client-controlled unless a trusted reverse
/// proxy overwrites them. A direct client can rotate values to dodge the
/// per-key limits, and all header-less clients share the `"unknown"` bucket.
/// Consider keying on the peer socket address when no proxy is in front.
fn client_key_from_headers(headers: &HeaderMap) -> String {
    ["X-Forwarded-For", "X-Real-IP"]
        .into_iter()
        .filter_map(|name| headers.get(name)?.to_str().ok())
        .map(|raw| raw.split(',').next().unwrap_or("").trim())
        .find(|candidate| !candidate.is_empty())
        .map_or_else(|| "unknown".into(), str::to_owned)
}
/// Shared state for all axum handlers.
///
/// Built once in `run_gateway` and cloned into each handler by axum's `State`
/// extractor. Every heavyweight member sits behind an `Arc`, so clones share
/// the same underlying objects.
#[derive(Clone)]
pub struct AppState {
    /// LLM provider used to answer `/webhook` and WhatsApp messages.
    pub provider: Arc<dyn Provider>,
    /// Model identifier passed to every `provider.chat` call.
    pub model: String,
    /// Sampling temperature passed to every `provider.chat` call.
    pub temperature: f64,
    /// Memory backend that receives auto-saved inbound messages.
    pub mem: Arc<dyn Memory>,
    /// When true, inbound messages are stored in `mem` before the LLM call.
    pub auto_save: bool,
    /// Optional shared secret that `/webhook` callers must send as `X-Webhook-Secret`.
    pub webhook_secret: Option<Arc<str>>,
    /// Issues bearer tokens via `/pair` and authenticates `/webhook` callers.
    pub pairing: Arc<PairingGuard>,
    /// Per-client limits for `/pair` and `/webhook`.
    pub rate_limiter: Arc<GatewayRateLimiter>,
    /// De-duplicates `/webhook` requests carrying `X-Idempotency-Key`.
    pub idempotency_store: Arc<IdempotencyStore>,
    /// WhatsApp channel; `None` makes the `/whatsapp` routes answer 404.
    pub whatsapp: Option<Arc<WhatsAppChannel>>,
    /// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
    pub whatsapp_app_secret: Option<Arc<str>>,
}
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
#[allow(clippy::too_many_lines)]
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
// ── Security: refuse public bind without tunnel or explicit opt-in ──
if is_public_bind(host) && config.tunnel.provider == "none" && !config.gateway.allow_public_bind
{
anyhow::bail!(
"🛑 Refusing to bind to {host} — gateway would be exposed to the internet.\n\
Fix: use --host 127.0.0.1 (default), configure a tunnel, or set\n\
[gateway] allow_public_bind = true in config.toml (NOT recommended)."
);
}
let addr: SocketAddr = format!("{host}:{port}").parse()?;
let listener = tokio::net::TcpListener::bind(addr).await?;
let actual_port = listener.local_addr()?.port();
let display_addr = format!("{host}:{actual_port}");
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
&config.reliability,
)?);
let model = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4".into());
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
&config.memory,
&config.workspace_dir,
config.api_key.as_deref(),
)?);
// Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config
.channels_config
.webhook
.as_ref()
.and_then(|w| w.secret.as_deref())
.map(Arc::from);
// WhatsApp channel (if configured)
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
config.channels_config.whatsapp.as_ref().map(|wa| {
Arc::new(WhatsAppChannel::new(
wa.access_token.clone(),
wa.phone_number_id.clone(),
wa.verify_token.clone(),
wa.allowed_numbers.clone(),
))
});
// WhatsApp app secret for webhook signature verification
// Priority: environment variable > config file
let whatsapp_app_secret: Option<Arc<str>> = std::env::var("ZEROCLAW_WHATSAPP_APP_SECRET")
.ok()
.and_then(|secret| {
let secret = secret.trim();
(!secret.is_empty()).then(|| secret.to_owned())
})
.or_else(|| {
config.channels_config.whatsapp.as_ref().and_then(|wa| {
wa.app_secret
.as_deref()
.map(str::trim)
.filter(|secret| !secret.is_empty())
.map(ToOwned::to_owned)
})
})
.map(Arc::from);
// ── Pairing guard ──────────────────────────────────────
let pairing = Arc::new(PairingGuard::new(
config.gateway.require_pairing,
&config.gateway.paired_tokens,
));
let rate_limiter = Arc::new(GatewayRateLimiter::new(
config.gateway.pair_rate_limit_per_minute,
config.gateway.webhook_rate_limit_per_minute,
));
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
config.gateway.idempotency_ttl_secs.max(1),
)));
// ── Tunnel ────────────────────────────────────────────────
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
let mut tunnel_url: Option<String> = None;
if let Some(ref tun) = tunnel {
println!("🔗 Starting {} tunnel...", tun.name());
match tun.start(host, actual_port).await {
Ok(url) => {
println!("🌐 Tunnel active: {url}");
tunnel_url = Some(url);
}
Err(e) => {
println!("⚠️ Tunnel failed to start: {e}");
println!(" Falling back to local-only mode.");
}
}
}
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
if let Some(ref url) = tunnel_url {
println!(" 🌐 Public URL: {url}");
}
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
if whatsapp_channel.is_some() {
println!(" GET /whatsapp — Meta webhook verification");
println!(" POST /whatsapp — WhatsApp message webhook");
}
println!(" GET /health — health check");
if let Some(code) = pairing.pairing_code() {
println!();
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
println!(" ┌──────────────┐");
println!("{code}");
println!(" └──────────────┘");
println!(" Send: POST /pair with header X-Pairing-Code: {code}");
} else if pairing.require_pairing() {
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
} else {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
}
if webhook_secret.is_some() {
println!(" 🔒 Webhook secret: ENABLED");
}
println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway");
// Build shared state
let state = AppState {
provider,
model,
temperature,
mem,
auto_save: config.memory.auto_save,
webhook_secret,
pairing,
rate_limiter,
idempotency_store,
whatsapp: whatsapp_channel,
whatsapp_app_secret,
};
// Build router with middleware
let app = Router::new()
.route("/health", get(handle_health))
.route("/pair", post(handle_pair))
.route("/webhook", post(handle_webhook))
.route("/whatsapp", get(handle_whatsapp_verify))
.route("/whatsapp", post(handle_whatsapp_message))
.with_state(state)
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
Duration::from_secs(REQUEST_TIMEOUT_SECS),
));
// Run the server
axum::serve(listener, app).await?;
Ok(())
}
// ══════════════════════════════════════════════════════════════════════════════
// AXUM HANDLERS
// ══════════════════════════════════════════════════════════════════════════════
/// GET /health — always public (no secrets leaked)
async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
let body = serde_json::json!({
"status": "ok",
"paired": state.pairing.is_paired(),
"runtime": crate::health::snapshot_json(),
});
Json(body)
}
/// POST /pair — exchange one-time code for bearer token
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
let client_key = client_key_from_headers(&headers);
if !state.rate_limiter.allow_pair(&client_key) {
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
let err = serde_json::json!({
"error": "Too many pairing requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
let code = headers
.get("X-Pairing-Code")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
match state.pairing.try_pair(code) {
Ok(Some(token)) => {
tracing::info!("🔐 New client paired successfully");
let body = serde_json::json!({
"paired": true,
"token": token,
"message": "Save this token — use it as Authorization: Bearer <token>"
});
(StatusCode::OK, Json(body))
}
Ok(None) => {
tracing::warn!("🔐 Pairing attempt with invalid code");
let err = serde_json::json!({"error": "Invalid pairing code"});
(StatusCode::FORBIDDEN, Json(err))
}
Err(lockout_secs) => {
tracing::warn!(
"🔐 Pairing locked out — too many failed attempts ({lockout_secs}s remaining)"
);
let err = serde_json::json!({
"error": format!("Too many failed attempts. Try again in {lockout_secs}s."),
"retry_after": lockout_secs
});
(StatusCode::TOO_MANY_REQUESTS, Json(err))
}
}
}
/// Webhook request body accepted by `POST /webhook`.
///
/// `message` is required; deserialization fails when it is absent.
#[derive(serde::Deserialize)]
pub struct WebhookBody {
    /// Prompt text sent to the LLM (and auto-saved to memory when enabled).
    pub message: String,
}
/// POST /webhook — main webhook endpoint.
///
/// Checks run in this order:
/// 1. Per-client rate limit (`429`).
/// 2. Bearer-token auth when pairing is required (`401`).
/// 3. Optional `X-Webhook-Secret` check (`401`).
/// 4. JSON body parse (`400`).
/// 5. Optional `X-Idempotency-Key` de-duplication (`200` with `"status": "duplicate"`).
///
/// The message is then optionally auto-saved to memory and sent to the LLM.
/// The result is `200 {"response", "model"}`, or `500` if the provider fails.
///
/// NOTE(review): the idempotency key is recorded before the provider call, so a
/// retry with the same key after an LLM failure is reported as a duplicate.
async fn handle_webhook(
    State(state): State<AppState>,
    headers: HeaderMap,
    body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse {
    let client_key = client_key_from_headers(&headers);
    if !state.rate_limiter.allow_webhook(&client_key) {
        tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
        let err = serde_json::json!({
            "error": "Too many webhook requests. Please retry later.",
            "retry_after": RATE_LIMIT_WINDOW_SECS,
        });
        return (StatusCode::TOO_MANY_REQUESTS, Json(err));
    }

    // ── Bearer token auth (pairing) ──
    if state.pairing.require_pairing() {
        let auth = headers
            .get(header::AUTHORIZATION)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        let token = auth.strip_prefix("Bearer ").unwrap_or("");
        if !state.pairing.is_authenticated(token) {
            tracing::warn!("Webhook: rejected — not paired / invalid bearer token");
            let err = serde_json::json!({
                "error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
            });
            return (StatusCode::UNAUTHORIZED, Json(err));
        }
    }

    // ── Webhook secret auth (optional, additional layer) ──
    if let Some(ref secret) = state.webhook_secret {
        let header_val = headers
            .get("X-Webhook-Secret")
            .and_then(|v| v.to_str().ok());
        match header_val {
            // Constant-time comparison avoids leaking the secret through timing.
            Some(val) if constant_time_eq(val, secret.as_ref()) => {}
            _ => {
                tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
                let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
                return (StatusCode::UNAUTHORIZED, Json(err));
            }
        }
    }

    // ── Parse body ──
    // The JSON rejection is taken as a `Result` so authentication runs first
    // and malformed bodies get a JSON error instead of axum's default.
    let Json(webhook_body) = match body {
        Ok(b) => b,
        Err(e) => {
            let err = serde_json::json!({
                "error": format!("Invalid JSON: {e}. Expected: {{\"message\": \"...\"}}")
            });
            return (StatusCode::BAD_REQUEST, Json(err));
        }
    };

    // ── Idempotency (optional) ──
    if let Some(idempotency_key) = headers
        .get("X-Idempotency-Key")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|value| !value.is_empty())
    {
        if !state.idempotency_store.record_if_new(idempotency_key) {
            tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})");
            let body = serde_json::json!({
                "status": "duplicate",
                "idempotent": true,
                "message": "Request already processed for this idempotency key"
            });
            return (StatusCode::OK, Json(body));
        }
    }

    let message = &webhook_body.message;

    if state.auto_save {
        let key = webhook_memory_key();
        // Auto-save is best-effort: a memory failure must not fail the request,
        // but it should not disappear silently either.
        if let Err(e) = state
            .mem
            .store(&key, message, MemoryCategory::Conversation)
            .await
        {
            tracing::warn!(
                "Webhook: failed to auto-save message to memory: {}",
                providers::sanitize_api_error(&e.to_string())
            );
        }
    }

    match state
        .provider
        .chat(message, &state.model, state.temperature)
        .await
    {
        Ok(response) => {
            let body = serde_json::json!({"response": response, "model": state.model});
            (StatusCode::OK, Json(body))
        }
        Err(e) => {
            // Provider errors may echo credentials; log a sanitized form and
            // return a generic message to the caller.
            tracing::error!(
                "Webhook provider error: {}",
                providers::sanitize_api_error(&e.to_string())
            );
            let err = serde_json::json!({"error": "LLM request failed"});
            (StatusCode::INTERNAL_SERVER_ERROR, Json(err))
        }
    }
}
/// `WhatsApp` verification query params (`hub.mode`, `hub.verify_token`,
/// `hub.challenge`) sent to `GET /whatsapp`.
///
/// Every field is optional, so a request missing any of them still reaches
/// the handler, which answers with 403 or 400 itself.
#[derive(serde::Deserialize)]
pub struct WhatsAppVerifyQuery {
    /// Expected to be `"subscribe"` for a verification handshake.
    #[serde(rename = "hub.mode")]
    pub mode: Option<String>,
    /// Must match the configured verify token.
    #[serde(rename = "hub.verify_token")]
    pub verify_token: Option<String>,
    /// Echoed back verbatim on successful verification.
    #[serde(rename = "hub.challenge")]
    pub challenge: Option<String>,
}
/// GET /whatsapp — Meta webhook verification handshake.
///
/// Responses:
/// - `404` when WhatsApp is not configured.
/// - `403` unless `hub.mode == "subscribe"` and `hub.verify_token` matches
///   the channel's token (compared in constant time).
/// - `400` when the request is valid but carries no `hub.challenge`.
/// - `200` echoing the challenge otherwise.
async fn handle_whatsapp_verify(
    State(state): State<AppState>,
    Query(params): Query<WhatsAppVerifyQuery>,
) -> impl IntoResponse {
    let Some(wa) = state.whatsapp.as_ref() else {
        return (StatusCode::NOT_FOUND, "WhatsApp not configured".to_string());
    };

    // Constant-time comparison so the verify token cannot be recovered by timing.
    let token_ok = match params.verify_token.as_deref() {
        Some(candidate) => constant_time_eq(candidate, wa.verify_token()),
        None => false,
    };
    let is_subscribe = params.mode.as_deref() == Some("subscribe");

    if !(is_subscribe && token_ok) {
        tracing::warn!("WhatsApp webhook verification failed — token mismatch");
        return (StatusCode::FORBIDDEN, "Forbidden".to_string());
    }

    match params.challenge {
        Some(challenge) => {
            tracing::info!("WhatsApp webhook verified successfully");
            (StatusCode::OK, challenge)
        }
        None => (StatusCode::BAD_REQUEST, "Missing hub.challenge".to_string()),
    }
}
/// Verify `WhatsApp` webhook signature (`X-Hub-Signature-256`).
///
/// `signature_header` must be `sha256=` (lower-case prefix) followed by the
/// hex-encoded HMAC-SHA256 of `body` keyed with `app_secret`. Returns `true`
/// only on an exact match. A missing prefix, invalid hex or a wrong-length
/// digest all yield `false`. The final comparison uses `Mac::verify_slice`,
/// which is constant-time.
///
/// See: <https://developers.facebook.com/docs/graph-api/webhooks/getting-started#verification-requests>
pub fn verify_whatsapp_signature(app_secret: &str, body: &[u8], signature_header: &str) -> bool {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    let claimed_digest = match signature_header.strip_prefix("sha256=").map(hex::decode) {
        Some(Ok(bytes)) => bytes,
        _ => return false,
    };

    match Hmac::<Sha256>::new_from_slice(app_secret.as_bytes()) {
        Ok(mut mac) => {
            mac.update(body);
            mac.verify_slice(&claimed_digest).is_ok()
        }
        Err(_) => false,
    }
}
/// POST /whatsapp — incoming message webhook
async fn handle_whatsapp_message(
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let Some(ref wa) = state.whatsapp else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "WhatsApp not configured"})),
);
};
// ── Security: Verify X-Hub-Signature-256 if app_secret is configured ──
if let Some(ref app_secret) = state.whatsapp_app_secret {
let signature = headers
.get("X-Hub-Signature-256")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !verify_whatsapp_signature(app_secret, &body, signature) {
tracing::warn!(
"WhatsApp webhook signature verification failed (signature: {})",
if signature.is_empty() {
"missing"
} else {
"invalid"
}
);
return (
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({"error": "Invalid signature"})),
);
}
}
// Parse JSON body
let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&body) else {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "Invalid JSON payload"})),
);
};
// Parse messages from the webhook payload
let messages = wa.parse_webhook_payload(&payload);
if messages.is_empty() {
// Acknowledge the webhook even if no messages (could be status updates)
return (StatusCode::OK, Json(serde_json::json!({"status": "ok"})));
}
// Process each message
for msg in &messages {
tracing::info!(
"WhatsApp message from {}: {}",
msg.sender,
truncate_with_ellipsis(&msg.content, 50)
);
// Auto-save to memory
if state.auto_save {
let key = whatsapp_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation)
.await;
}
// Call the LLM
match state
.provider
.chat(&msg.content, &state.model, state.temperature)
.await
{
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa.send(&response, &msg.sender).await {
tracing::error!("Failed to send WhatsApp reply: {e}");
}
}
Err(e) => {
tracing::error!("LLM error for WhatsApp message: {e:#}");
let _ = wa
.send(
"Sorry, I couldn't process your message right now.",
&msg.sender,
)
.await;
}
}
}
// Acknowledge the webhook
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
}
// Unit tests for the gateway. Covered areas: security constants, webhook body
// parsing, rate limiting, idempotency, memory-key formats, `/webhook` handler
// behaviour (via in-process mocks), and WhatsApp HMAC signature verification.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::traits::ChannelMessage;
    use crate::memory::{Memory, MemoryCategory, MemoryEntry};
    use crate::providers::Provider;
    use async_trait::async_trait;
    use axum::http::HeaderValue;
    use axum::response::IntoResponse;
    use http_body_util::BodyExt;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    // Pins the body-size limit so an accidental change is caught in review.
    #[test]
    fn security_body_limit_is_64kb() {
        assert_eq!(MAX_BODY_SIZE, 65_536);
    }

    // Pins the request timeout used against slow-loris style clients.
    #[test]
    fn security_timeout_is_30_seconds() {
        assert_eq!(REQUEST_TIMEOUT_SECS, 30);
    }

    // `message` is mandatory: a body without it must fail to deserialize.
    #[test]
    fn webhook_body_requires_message_field() {
        let valid = r#"{"message": "hello"}"#;
        let parsed: Result<WebhookBody, _> = serde_json::from_str(valid);
        assert!(parsed.is_ok());
        assert_eq!(parsed.unwrap().message, "hello");
        let missing = r#"{"other": "field"}"#;
        let parsed: Result<WebhookBody, _> = serde_json::from_str(missing);
        assert!(parsed.is_err());
    }

    // Every verification query field can be absent.
    #[test]
    fn whatsapp_query_fields_are_optional() {
        let q = WhatsAppVerifyQuery {
            mode: None,
            verify_token: None,
            challenge: None,
        };
        assert!(q.mode.is_none());
    }

    // Compile-time check: axum's `State` extractor requires `AppState: Clone`.
    #[test]
    fn app_state_is_clone() {
        fn assert_clone<T: Clone>() {}
        assert_clone::<AppState>();
    }

    // With a pair limit of 2, the third request from the same key is rejected.
    #[test]
    fn gateway_rate_limiter_blocks_after_limit() {
        let limiter = GatewayRateLimiter::new(2, 2);
        assert!(limiter.allow_pair("127.0.0.1"));
        assert!(limiter.allow_pair("127.0.0.1"));
        assert!(!limiter.allow_pair("127.0.0.1"));
    }

    // A repeated key is rejected within the TTL; distinct keys are accepted.
    #[test]
    fn idempotency_store_rejects_duplicate_key() {
        let store = IdempotencyStore::new(Duration::from_secs(30));
        assert!(store.record_if_new("req-1"));
        assert!(!store.record_if_new("req-1"));
        assert!(store.record_if_new("req-2"));
    }

    // Two generated webhook keys share the prefix but never collide.
    #[test]
    fn webhook_memory_key_is_unique() {
        let key1 = webhook_memory_key();
        let key2 = webhook_memory_key();
        assert!(key1.starts_with("webhook_msg_"));
        assert!(key2.starts_with("webhook_msg_"));
        assert_ne!(key1, key2);
    }

    // WhatsApp keys are `whatsapp_<sender>_<id>`.
    #[test]
    fn whatsapp_memory_key_includes_sender_and_message_id() {
        let msg = ChannelMessage {
            id: "wamid-123".into(),
            sender: "+1234567890".into(),
            content: "hello".into(),
            channel: "whatsapp".into(),
            timestamp: 1,
        };
        let key = whatsapp_memory_key(&msg);
        assert_eq!(key, "whatsapp_+1234567890_wamid-123");
    }

    /// Memory backend that accepts every write and stores nothing.
    #[derive(Default)]
    struct MockMemory;

    #[async_trait]
    impl Memory for MockMemory {
        fn name(&self) -> &str {
            "mock"
        }

        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }

        async fn health_check(&self) -> bool {
            true
        }
    }

    /// Provider that always answers `"ok"` and counts how often it was called.
    #[derive(Default)]
    struct MockProvider {
        calls: AtomicUsize,
    }

    #[async_trait]
    impl Provider for MockProvider {
        // Only `chat_with_system` is implemented. The handlers' `chat` calls are
        // assumed to route here via the trait's default method — TODO confirm
        // against the `Provider` trait definition.
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok("ok".into())
        }
    }

    /// Memory backend that records every key passed to `store`.
    #[derive(Default)]
    struct TrackingMemory {
        keys: Mutex<Vec<String>>,
    }

    #[async_trait]
    impl Memory for TrackingMemory {
        fn name(&self) -> &str {
            "tracking"
        }

        async fn store(
            &self,
            key: &str,
            _content: &str,
            _category: MemoryCategory,
        ) -> anyhow::Result<()> {
            self.keys
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push(key.to_string());
            Ok(())
        }

        async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        // Reports the number of recorded keys.
        async fn count(&self) -> anyhow::Result<usize> {
            let size = self
                .keys
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .len();
            Ok(size)
        }

        async fn health_check(&self) -> bool {
            true
        }
    }

    // Two requests with the same `X-Idempotency-Key`: the second is answered
    // as a duplicate and the provider is invoked only once.
    #[tokio::test]
    async fn webhook_idempotency_skips_duplicate_provider_calls() {
        let provider_impl = Arc::new(MockProvider::default());
        let provider: Arc<dyn Provider> = provider_impl.clone();
        let memory: Arc<dyn Memory> = Arc::new(MockMemory);

        let state = AppState {
            provider,
            model: "test-model".into(),
            temperature: 0.0,
            mem: memory,
            auto_save: false,
            webhook_secret: None,
            pairing: Arc::new(PairingGuard::new(false, &[])),
            rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
            idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
            whatsapp: None,
            whatsapp_app_secret: None,
        };

        let mut headers = HeaderMap::new();
        headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));

        let body = Ok(Json(WebhookBody {
            message: "hello".into(),
        }));
        let first = handle_webhook(State(state.clone()), headers.clone(), body)
            .await
            .into_response();
        assert_eq!(first.status(), StatusCode::OK);

        let body = Ok(Json(WebhookBody {
            message: "hello".into(),
        }));
        let second = handle_webhook(State(state), headers, body)
            .await
            .into_response();
        assert_eq!(second.status(), StatusCode::OK);

        let payload = second.into_body().collect().await.unwrap().to_bytes();
        let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
        assert_eq!(parsed["status"], "duplicate");
        assert_eq!(parsed["idempotent"], true);
        assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
    }

    // With auto-save on, each request is stored under its own `webhook_msg_*` key.
    #[tokio::test]
    async fn webhook_autosave_stores_distinct_keys_per_request() {
        let provider_impl = Arc::new(MockProvider::default());
        let provider: Arc<dyn Provider> = provider_impl.clone();

        let tracking_impl = Arc::new(TrackingMemory::default());
        let memory: Arc<dyn Memory> = tracking_impl.clone();

        let state = AppState {
            provider,
            model: "test-model".into(),
            temperature: 0.0,
            mem: memory,
            auto_save: true,
            webhook_secret: None,
            pairing: Arc::new(PairingGuard::new(false, &[])),
            rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
            idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
            whatsapp: None,
            whatsapp_app_secret: None,
        };

        let headers = HeaderMap::new();

        let body1 = Ok(Json(WebhookBody {
            message: "hello one".into(),
        }));
        let first = handle_webhook(State(state.clone()), headers.clone(), body1)
            .await
            .into_response();
        assert_eq!(first.status(), StatusCode::OK);

        let body2 = Ok(Json(WebhookBody {
            message: "hello two".into(),
        }));
        let second = handle_webhook(State(state), headers, body2)
            .await
            .into_response();
        assert_eq!(second.status(), StatusCode::OK);

        let keys = tracking_impl
            .keys
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        assert_eq!(keys.len(), 2);
        assert_ne!(keys[0], keys[1]);
        assert!(keys[0].starts_with("webhook_msg_"));
        assert!(keys[1].starts_with("webhook_msg_"));
        assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
    }

    // ══════════════════════════════════════════════════════════
    // WhatsApp Signature Verification Tests (CWE-345 Prevention)
    // ══════════════════════════════════════════════════════════

    /// Hex-encoded HMAC-SHA256 of `body` keyed with `secret` (reference value).
    fn compute_whatsapp_signature_hex(secret: &str, body: &[u8]) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        mac.update(body);
        hex::encode(mac.finalize().into_bytes())
    }

    /// Reference signature formatted as an `X-Hub-Signature-256` header value.
    fn compute_whatsapp_signature_header(secret: &str, body: &[u8]) -> String {
        format!("sha256={}", compute_whatsapp_signature_hex(secret, body))
    }

    // A correctly computed signature is accepted.
    #[test]
    fn whatsapp_signature_valid() {
        // Test with known values
        let app_secret = "test_secret_key";
        let body = b"test body content";
        let signature_header = compute_whatsapp_signature_header(app_secret, body);
        assert!(verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // A signature made with a different secret is rejected.
    #[test]
    fn whatsapp_signature_invalid_wrong_secret() {
        let app_secret = "correct_secret";
        let wrong_secret = "wrong_secret";
        let body = b"test body content";
        let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
        assert!(!verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // A body altered after signing is rejected.
    #[test]
    fn whatsapp_signature_invalid_wrong_body() {
        let app_secret = "test_secret";
        let original_body = b"original body";
        let tampered_body = b"tampered body";
        let signature_header = compute_whatsapp_signature_header(app_secret, original_body);
        // Verify with tampered body should fail
        assert!(!verify_whatsapp_signature(
            app_secret,
            tampered_body,
            &signature_header
        ));
    }

    // The `sha256=` prefix is mandatory.
    #[test]
    fn whatsapp_signature_missing_prefix() {
        let app_secret = "test_secret";
        let body = b"test body";
        // Signature without "sha256=" prefix
        let signature_header = "abc123def456";
        assert!(!verify_whatsapp_signature(
            app_secret,
            body,
            signature_header
        ));
    }

    // An empty header (signature missing) is rejected.
    #[test]
    fn whatsapp_signature_empty_header() {
        let app_secret = "test_secret";
        let body = b"test body";
        assert!(!verify_whatsapp_signature(app_secret, body, ""));
    }

    // Non-hex digest text is rejected.
    #[test]
    fn whatsapp_signature_invalid_hex() {
        let app_secret = "test_secret";
        let body = b"test body";
        // Invalid hex characters
        let signature_header = "sha256=not_valid_hex_zzz";
        assert!(!verify_whatsapp_signature(
            app_secret,
            body,
            signature_header
        ));
    }

    // An empty body with a matching signature is accepted.
    #[test]
    fn whatsapp_signature_empty_body() {
        let app_secret = "test_secret";
        let body = b"";
        let signature_header = compute_whatsapp_signature_header(app_secret, body);
        assert!(verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // Multi-byte UTF-8 bodies are verified byte-for-byte.
    #[test]
    fn whatsapp_signature_unicode_body() {
        let app_secret = "test_secret";
        let body = "Hello 🦀 世界".as_bytes();
        let signature_header = compute_whatsapp_signature_header(app_secret, body);
        assert!(verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // A realistic JSON webhook payload verifies correctly.
    #[test]
    fn whatsapp_signature_json_payload() {
        let app_secret = "my_app_secret_from_meta";
        let body = br#"{"entry":[{"changes":[{"value":{"messages":[{"from":"1234567890","text":{"body":"Hello"}}]}}]}]}"#;
        let signature_header = compute_whatsapp_signature_header(app_secret, body);
        assert!(verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // The prefix match is case-sensitive: only lower-case `sha256=` is accepted.
    #[test]
    fn whatsapp_signature_case_sensitive_prefix() {
        let app_secret = "test_secret";
        let body = b"test body";
        let hex_sig = compute_whatsapp_signature_hex(app_secret, body);
        // Wrong case prefix should fail
        let wrong_prefix = format!("SHA256={hex_sig}");
        assert!(!verify_whatsapp_signature(app_secret, body, &wrong_prefix));
        // Correct prefix should pass
        let correct_prefix = format!("sha256={hex_sig}");
        assert!(verify_whatsapp_signature(app_secret, body, &correct_prefix));
    }

    // A truncated digest (prefix of the real one) is rejected.
    #[test]
    fn whatsapp_signature_truncated_hex() {
        let app_secret = "test_secret";
        let body = b"test body";
        let hex_sig = compute_whatsapp_signature_hex(app_secret, body);
        let truncated = &hex_sig[..32]; // Only half the signature
        let signature_header = format!("sha256={truncated}");
        assert!(!verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }

    // A digest with extra trailing bytes is rejected.
    #[test]
    fn whatsapp_signature_extra_bytes() {
        let app_secret = "test_secret";
        let body = b"test body";
        let hex_sig = compute_whatsapp_signature_hex(app_secret, body);
        let extended = format!("{hex_sig}deadbeef");
        let signature_header = format!("sha256={extended}");
        assert!(!verify_whatsapp_signature(
            app_secret,
            body,
            &signature_header
        ));
    }
}