feat(config): make config writes atomic with rollback-safe replacement (#190)
* feat(runtime): add Docker runtime MVP and runtime-aware command builder * feat(security): add shell risk classification, approval gates, and action throttling * feat(gateway): add per-endpoint rate limiting and webhook idempotency * feat(config): make config writes atomic with rollback-safe replacement --------- Co-authored-by: chumyin <chumyin@users.noreply.github.com>
This commit is contained in:
parent
f1e3b1166d
commit
b0e1e32819
11 changed files with 1202 additions and 67 deletions
|
|
@ -22,9 +22,10 @@ use axum::{
|
|||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
|
||||
|
|
@ -32,6 +33,118 @@ use tower_http::timeout::TimeoutLayer;
|
|||
pub const MAX_BODY_SIZE: usize = 65_536;
|
||||
/// Request timeout (30s) — prevents slow-loris attacks
|
||||
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||
/// Sliding window used by gateway rate limiting.
|
||||
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SlidingWindowRateLimiter {
|
||||
limit_per_window: u32,
|
||||
window: Duration,
|
||||
requests: Mutex<HashMap<String, Vec<Instant>>>,
|
||||
}
|
||||
|
||||
impl SlidingWindowRateLimiter {
|
||||
fn new(limit_per_window: u32, window: Duration) -> Self {
|
||||
Self {
|
||||
limit_per_window,
|
||||
window,
|
||||
requests: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn allow(&self, key: &str) -> bool {
|
||||
if self.limit_per_window == 0 {
|
||||
return true;
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
||||
|
||||
let mut requests = self
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
|
||||
let entry = requests.entry(key.to_owned()).or_default();
|
||||
entry.retain(|instant| *instant > cutoff);
|
||||
|
||||
if entry.len() >= self.limit_per_window as usize {
|
||||
return false;
|
||||
}
|
||||
|
||||
entry.push(now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GatewayRateLimiter {
|
||||
pair: SlidingWindowRateLimiter,
|
||||
webhook: SlidingWindowRateLimiter,
|
||||
}
|
||||
|
||||
impl GatewayRateLimiter {
|
||||
fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
|
||||
let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
||||
Self {
|
||||
pair: SlidingWindowRateLimiter::new(pair_per_minute, window),
|
||||
webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window),
|
||||
}
|
||||
}
|
||||
|
||||
fn allow_pair(&self, key: &str) -> bool {
|
||||
self.pair.allow(key)
|
||||
}
|
||||
|
||||
fn allow_webhook(&self, key: &str) -> bool {
|
||||
self.webhook.allow(key)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IdempotencyStore {
|
||||
ttl: Duration,
|
||||
keys: Mutex<HashMap<String, Instant>>,
|
||||
}
|
||||
|
||||
impl IdempotencyStore {
|
||||
fn new(ttl: Duration) -> Self {
|
||||
Self {
|
||||
ttl,
|
||||
keys: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this key is new and is now recorded.
|
||||
fn record_if_new(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
let mut keys = self
|
||||
.keys
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
|
||||
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
||||
|
||||
if keys.contains_key(key) {
|
||||
return false;
|
||||
}
|
||||
|
||||
keys.insert(key.to_owned(), now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn client_key_from_headers(headers: &HeaderMap) -> String {
|
||||
for header_name in ["X-Forwarded-For", "X-Real-IP"] {
|
||||
if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) {
|
||||
let first = value.split(',').next().unwrap_or("").trim();
|
||||
if !first.is_empty() {
|
||||
return first.to_owned();
|
||||
}
|
||||
}
|
||||
}
|
||||
"unknown".into()
|
||||
}
|
||||
|
||||
/// Shared state for all axum handlers
|
||||
#[derive(Clone)]
|
||||
|
|
@ -43,6 +156,8 @@ pub struct AppState {
|
|||
pub auto_save: bool,
|
||||
pub webhook_secret: Option<Arc<str>>,
|
||||
pub pairing: Arc<PairingGuard>,
|
||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||
pub idempotency_store: Arc<IdempotencyStore>,
|
||||
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
||||
/// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
|
||||
pub whatsapp_app_secret: Option<Arc<str>>,
|
||||
|
|
@ -66,17 +181,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
let actual_port = listener.local_addr()?.port();
|
||||
let display_addr = format!("{host}:{actual_port}");
|
||||
|
||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
&config.reliability,
|
||||
)?);
|
||||
let model = config
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_routed_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model,
|
||||
)?);
|
||||
let temperature = config.default_temperature;
|
||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||
&config.memory,
|
||||
|
|
@ -127,6 +240,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
config.gateway.require_pairing,
|
||||
&config.gateway.paired_tokens,
|
||||
));
|
||||
let rate_limiter = Arc::new(GatewayRateLimiter::new(
|
||||
config.gateway.pair_rate_limit_per_minute,
|
||||
config.gateway.webhook_rate_limit_per_minute,
|
||||
));
|
||||
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
|
||||
config.gateway.idempotency_ttl_secs.max(1),
|
||||
)));
|
||||
|
||||
// ── Tunnel ────────────────────────────────────────────────
|
||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||
|
|
@ -185,6 +305,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
auto_save: config.memory.auto_save,
|
||||
webhook_secret,
|
||||
pairing,
|
||||
rate_limiter,
|
||||
idempotency_store,
|
||||
whatsapp: whatsapp_channel,
|
||||
whatsapp_app_secret,
|
||||
};
|
||||
|
|
@ -225,6 +347,16 @@ async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
|
|||
|
||||
/// POST /pair — exchange one-time code for bearer token
|
||||
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
||||
let client_key = client_key_from_headers(&headers);
|
||||
if !state.rate_limiter.allow_pair(&client_key) {
|
||||
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
|
||||
let err = serde_json::json!({
|
||||
"error": "Too many pairing requests. Please retry later.",
|
||||
"retry_after": RATE_LIMIT_WINDOW_SECS,
|
||||
});
|
||||
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
|
||||
}
|
||||
|
||||
let code = headers
|
||||
.get("X-Pairing-Code")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
|
|
@ -270,6 +402,16 @@ async fn handle_webhook(
|
|||
headers: HeaderMap,
|
||||
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
||||
) -> impl IntoResponse {
|
||||
let client_key = client_key_from_headers(&headers);
|
||||
if !state.rate_limiter.allow_webhook(&client_key) {
|
||||
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
|
||||
let err = serde_json::json!({
|
||||
"error": "Too many webhook requests. Please retry later.",
|
||||
"retry_after": RATE_LIMIT_WINDOW_SECS,
|
||||
});
|
||||
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
|
||||
}
|
||||
|
||||
// ── Bearer token auth (pairing) ──
|
||||
if state.pairing.require_pairing() {
|
||||
let auth = headers
|
||||
|
|
@ -312,6 +454,24 @@ async fn handle_webhook(
|
|||
}
|
||||
};
|
||||
|
||||
// ── Idempotency (optional) ──
|
||||
if let Some(idempotency_key) = headers
|
||||
.get("X-Idempotency-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
if !state.idempotency_store.record_if_new(idempotency_key) {
|
||||
tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})");
|
||||
let body = serde_json::json!({
|
||||
"status": "duplicate",
|
||||
"idempotent": true,
|
||||
"message": "Request already processed for this idempotency key"
|
||||
});
|
||||
return (StatusCode::OK, Json(body));
|
||||
}
|
||||
}
|
||||
|
||||
let message = &webhook_body.message;
|
||||
|
||||
if state.auto_save {
|
||||
|
|
@ -508,6 +668,13 @@ async fn handle_whatsapp_message(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
|
||||
use crate::providers::Provider;
|
||||
use async_trait::async_trait;
|
||||
use axum::http::HeaderValue;
|
||||
use axum::response::IntoResponse;
|
||||
use http_body_util::BodyExt;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[test]
|
||||
fn security_body_limit_is_64kb() {
|
||||
|
|
@ -547,6 +714,133 @@ mod tests {
|
|||
assert_clone::<AppState>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gateway_rate_limiter_blocks_after_limit() {
|
||||
let limiter = GatewayRateLimiter::new(2, 2);
|
||||
assert!(limiter.allow_pair("127.0.0.1"));
|
||||
assert!(limiter.allow_pair("127.0.0.1"));
|
||||
assert!(!limiter.allow_pair("127.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn idempotency_store_rejects_duplicate_key() {
|
||||
let store = IdempotencyStore::new(Duration::from_secs(30));
|
||||
assert!(store.record_if_new("req-1"));
|
||||
assert!(!store.record_if_new("req-1"));
|
||||
assert!(store.record_if_new("req-2"));
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct MockMemory;
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for MockMemory {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct MockProvider {
|
||||
calls: AtomicUsize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MockProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
Ok("ok".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret: None,
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
||||
|
||||
let body = Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
}));
|
||||
let first = handle_webhook(State(state.clone()), headers.clone(), body)
|
||||
.await
|
||||
.into_response();
|
||||
assert_eq!(first.status(), StatusCode::OK);
|
||||
|
||||
let body = Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
}));
|
||||
let second = handle_webhook(State(state), headers, body)
|
||||
.await
|
||||
.into_response();
|
||||
assert_eq!(second.status(), StatusCode::OK);
|
||||
|
||||
let payload = second.into_body().collect().await.unwrap().to_bytes();
|
||||
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
|
||||
assert_eq!(parsed["status"], "duplicate");
|
||||
assert_eq!(parsed["idempotent"], true);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════
|
||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||
// ══════════════════════════════════════════════════════════
|
||||
|
|
@ -572,7 +866,11 @@ mod tests {
|
|||
|
||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||
|
||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
||||
assert!(verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
&signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -583,7 +881,11 @@ mod tests {
|
|||
|
||||
let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
|
||||
|
||||
assert!(!verify_whatsapp_signature(app_secret, body, &signature_header));
|
||||
assert!(!verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
&signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -610,7 +912,11 @@ mod tests {
|
|||
// Signature without "sha256=" prefix
|
||||
let signature_header = "abc123def456";
|
||||
|
||||
assert!(!verify_whatsapp_signature(app_secret, body, signature_header));
|
||||
assert!(!verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -643,7 +949,11 @@ mod tests {
|
|||
|
||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||
|
||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
||||
assert!(verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
&signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -653,7 +963,11 @@ mod tests {
|
|||
|
||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||
|
||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
||||
assert!(verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
&signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -663,7 +977,11 @@ mod tests {
|
|||
|
||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||
|
||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
||||
assert!(verify_whatsapp_signature(
|
||||
app_secret,
|
||||
body,
|
||||
&signature_header
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue