feat(config): make config writes atomic with rollback-safe replacement (#190)

* feat(runtime): add Docker runtime MVP and runtime-aware command builder

* feat(security): add shell risk classification, approval gates, and action throttling

* feat(gateway): add per-endpoint rate limiting and webhook idempotency

* feat(config): make config writes atomic with rollback-safe replacement

---------

Co-authored-by: chumyin <chumyin@users.noreply.github.com>
This commit is contained in:
Chummy 2026-02-16 01:18:45 +08:00 committed by GitHub
parent f1e3b1166d
commit b0e1e32819
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1202 additions and 67 deletions

View file

@ -22,9 +22,10 @@ use axum::{
routing::{get, post},
Router,
};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer;
@ -32,6 +33,118 @@ use tower_http::timeout::TimeoutLayer;
pub const MAX_BODY_SIZE: usize = 65_536;
/// Request timeout (30s) — prevents slow-loris attacks
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Sliding window used by gateway rate limiting.
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
#[derive(Debug)]
struct SlidingWindowRateLimiter {
limit_per_window: u32,
window: Duration,
requests: Mutex<HashMap<String, Vec<Instant>>>,
}
impl SlidingWindowRateLimiter {
fn new(limit_per_window: u32, window: Duration) -> Self {
Self {
limit_per_window,
window,
requests: Mutex::new(HashMap::new()),
}
}
fn allow(&self, key: &str) -> bool {
if self.limit_per_window == 0 {
return true;
}
let now = Instant::now();
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
let mut requests = self
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let entry = requests.entry(key.to_owned()).or_default();
entry.retain(|instant| *instant > cutoff);
if entry.len() >= self.limit_per_window as usize {
return false;
}
entry.push(now);
true
}
}
#[derive(Debug)]
pub struct GatewayRateLimiter {
pair: SlidingWindowRateLimiter,
webhook: SlidingWindowRateLimiter,
}
impl GatewayRateLimiter {
fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
Self {
pair: SlidingWindowRateLimiter::new(pair_per_minute, window),
webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window),
}
}
fn allow_pair(&self, key: &str) -> bool {
self.pair.allow(key)
}
fn allow_webhook(&self, key: &str) -> bool {
self.webhook.allow(key)
}
}
#[derive(Debug)]
pub struct IdempotencyStore {
ttl: Duration,
keys: Mutex<HashMap<String, Instant>>,
}
impl IdempotencyStore {
fn new(ttl: Duration) -> Self {
Self {
ttl,
keys: Mutex::new(HashMap::new()),
}
}
/// Returns true if this key is new and is now recorded.
fn record_if_new(&self, key: &str) -> bool {
let now = Instant::now();
let mut keys = self
.keys
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
if keys.contains_key(key) {
return false;
}
keys.insert(key.to_owned(), now);
true
}
}
fn client_key_from_headers(headers: &HeaderMap) -> String {
for header_name in ["X-Forwarded-For", "X-Real-IP"] {
if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) {
let first = value.split(',').next().unwrap_or("").trim();
if !first.is_empty() {
return first.to_owned();
}
}
}
"unknown".into()
}
/// Shared state for all axum handlers
#[derive(Clone)]
@ -43,6 +156,8 @@ pub struct AppState {
pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>,
pub rate_limiter: Arc<GatewayRateLimiter>,
pub idempotency_store: Arc<IdempotencyStore>,
pub whatsapp: Option<Arc<WhatsAppChannel>>,
/// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
pub whatsapp_app_secret: Option<Arc<str>>,
@ -66,17 +181,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let actual_port = listener.local_addr()?.port();
let display_addr = format!("{host}:{actual_port}");
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
&config.reliability,
)?);
let model = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let provider: Arc<dyn Provider> = Arc::from(providers::create_routed_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
&config.reliability,
&config.model_routes,
&model,
)?);
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
&config.memory,
@ -127,6 +240,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
config.gateway.require_pairing,
&config.gateway.paired_tokens,
));
let rate_limiter = Arc::new(GatewayRateLimiter::new(
config.gateway.pair_rate_limit_per_minute,
config.gateway.webhook_rate_limit_per_minute,
));
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
config.gateway.idempotency_ttl_secs.max(1),
)));
// ── Tunnel ────────────────────────────────────────────────
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
@ -185,6 +305,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
auto_save: config.memory.auto_save,
webhook_secret,
pairing,
rate_limiter,
idempotency_store,
whatsapp: whatsapp_channel,
whatsapp_app_secret,
};
@ -225,6 +347,16 @@ async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
/// POST /pair — exchange one-time code for bearer token
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
let client_key = client_key_from_headers(&headers);
if !state.rate_limiter.allow_pair(&client_key) {
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
let err = serde_json::json!({
"error": "Too many pairing requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
let code = headers
.get("X-Pairing-Code")
.and_then(|v| v.to_str().ok())
@ -270,6 +402,16 @@ async fn handle_webhook(
headers: HeaderMap,
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse {
let client_key = client_key_from_headers(&headers);
if !state.rate_limiter.allow_webhook(&client_key) {
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
let err = serde_json::json!({
"error": "Too many webhook requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
// ── Bearer token auth (pairing) ──
if state.pairing.require_pairing() {
let auth = headers
@ -312,6 +454,24 @@ async fn handle_webhook(
}
};
// ── Idempotency (optional) ──
if let Some(idempotency_key) = headers
.get("X-Idempotency-Key")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
{
if !state.idempotency_store.record_if_new(idempotency_key) {
tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})");
let body = serde_json::json!({
"status": "duplicate",
"idempotent": true,
"message": "Request already processed for this idempotency key"
});
return (StatusCode::OK, Json(body));
}
}
let message = &webhook_body.message;
if state.auto_save {
@ -508,6 +668,13 @@ async fn handle_whatsapp_message(
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
use crate::providers::Provider;
use async_trait::async_trait;
use axum::http::HeaderValue;
use axum::response::IntoResponse;
use http_body_util::BodyExt;
use std::sync::atomic::{AtomicUsize, Ordering};
#[test]
fn security_body_limit_is_64kb() {
@ -547,6 +714,133 @@ mod tests {
assert_clone::<AppState>();
}
#[test]
fn gateway_rate_limiter_blocks_after_limit() {
let limiter = GatewayRateLimiter::new(2, 2);
assert!(limiter.allow_pair("127.0.0.1"));
assert!(limiter.allow_pair("127.0.0.1"));
assert!(!limiter.allow_pair("127.0.0.1"));
}
#[test]
fn idempotency_store_rejects_duplicate_key() {
let store = IdempotencyStore::new(Duration::from_secs(30));
assert!(store.record_if_new("req-1"));
assert!(!store.record_if_new("req-1"));
assert!(store.record_if_new("req-2"));
}
#[derive(Default)]
struct MockMemory;
#[async_trait]
impl Memory for MockMemory {
fn name(&self) -> &str {
"mock"
}
async fn store(
&self,
_key: &str,
_content: &str,
_category: MemoryCategory,
) -> anyhow::Result<()> {
Ok(())
}
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
Ok(None)
}
async fn list(
&self,
_category: Option<&MemoryCategory>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
Ok(false)
}
async fn count(&self) -> anyhow::Result<usize> {
Ok(0)
}
async fn health_check(&self) -> bool {
true
}
}
#[derive(Default)]
struct MockProvider {
calls: AtomicUsize,
}
#[async_trait]
impl Provider for MockProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
self.calls.fetch_add(1, Ordering::SeqCst);
Ok("ok".into())
}
}
#[tokio::test]
async fn webhook_idempotency_skips_duplicate_provider_calls() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let mut headers = HeaderMap::new();
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
let body = Ok(Json(WebhookBody {
message: "hello".into(),
}));
let first = handle_webhook(State(state.clone()), headers.clone(), body)
.await
.into_response();
assert_eq!(first.status(), StatusCode::OK);
let body = Ok(Json(WebhookBody {
message: "hello".into(),
}));
let second = handle_webhook(State(state), headers, body)
.await
.into_response();
assert_eq!(second.status(), StatusCode::OK);
let payload = second.into_body().collect().await.unwrap().to_bytes();
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
assert_eq!(parsed["status"], "duplicate");
assert_eq!(parsed["idempotent"], true);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
}
// ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════
@ -572,7 +866,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
}
#[test]
@ -583,7 +881,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
assert!(!verify_whatsapp_signature(app_secret, body, &signature_header));
assert!(!verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
}
#[test]
@ -610,7 +912,11 @@ mod tests {
// Signature without "sha256=" prefix
let signature_header = "abc123def456";
assert!(!verify_whatsapp_signature(app_secret, body, signature_header));
assert!(!verify_whatsapp_signature(
app_secret,
body,
signature_header
));
}
#[test]
@ -643,7 +949,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
}
#[test]
@ -653,7 +963,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
}
#[test]
@ -663,7 +977,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
}
#[test]