diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index af7cf91..eac6571 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -21,8 +21,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { ); } - let addr = format!("{host}:{port}"); - let listener = TcpListener::bind(&addr).await?; + let listener = TcpListener::bind(format!("{host}:{port}")).await?; + let actual_port = listener.local_addr()?.port(); + let addr = format!("{host}:{actual_port}"); let provider: Arc = Arc::from(providers::create_provider( config.default_provider.as_deref().unwrap_or("openrouter"), @@ -59,7 +60,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { if let Some(ref tun) = tunnel { println!("🔗 Starting {} tunnel...", tun.name()); - match tun.start(host, port).await { + match tun.start(host, actual_port).await { Ok(url) => { println!("🌐 Tunnel active: {url}"); tunnel_url = Some(url); @@ -312,6 +313,111 @@ async fn send_response( #[cfg(test)] mod tests { use super::*; + use tokio::net::TcpListener as TokioListener; + + // ── Port allocation tests ──────────────────────────────── + + #[tokio::test] + async fn port_zero_binds_to_random_port() { + let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let actual = listener.local_addr().unwrap().port(); + assert_ne!(actual, 0, "OS must assign a non-zero port"); + assert!(actual > 0, "Actual port must be positive"); + } + + #[tokio::test] + async fn port_zero_assigns_different_ports() { + let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let l2 = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let p1 = l1.local_addr().unwrap().port(); + let p2 = l2.local_addr().unwrap().port(); + assert_ne!(p1, p2, "Two port-0 binds should get different ports"); + } + + #[tokio::test] + async fn port_zero_assigns_high_port() { + let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let actual = listener.local_addr().unwrap().port(); + // OS typically assigns ephemeral ports >= 1024 + assert!( + actual >= 1024, + "Random port {actual} should be >= 1024 (unprivileged)" + ); + } + + #[tokio::test] + async fn specific_port_binds_exactly() { + // Find a free port first via port 0, then rebind to it + let tmp = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let free_port = tmp.local_addr().unwrap().port(); + drop(tmp); + + let listener = TokioListener::bind(format!("127.0.0.1:{free_port}")) + .await + .unwrap(); + let actual = listener.local_addr().unwrap().port(); + assert_eq!(actual, free_port, "Specific port bind must match exactly"); + } + + #[tokio::test] + async fn actual_port_matches_addr_format() { + let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let actual_port = listener.local_addr().unwrap().port(); + let addr = format!("127.0.0.1:{actual_port}"); + assert!( + addr.starts_with("127.0.0.1:"), + "Addr format must include host" + ); + assert!( + !addr.ends_with(":0"), + "Addr must not contain port 0 after binding" + ); + } + + #[tokio::test] + async fn port_zero_listener_accepts_connections() { + let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let actual_port = listener.local_addr().unwrap().port(); + + // Spawn a client that connects + let client = tokio::spawn(async move { + tokio::net::TcpStream::connect(format!("127.0.0.1:{actual_port}")) + .await + .unwrap() + }); + + // Accept the connection + let (stream, _peer) = listener.accept().await.unwrap(); + assert!(stream.peer_addr().is_ok()); + client.await.unwrap(); + } + + #[tokio::test] + async fn duplicate_specific_port_fails() { + let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap(); + let port = l1.local_addr().unwrap().port(); + // Try to bind the same port while l1 is still alive + let result = TokioListener::bind(format!("127.0.0.1:{port}")).await; + assert!(result.is_err(), "Binding an already-used port must fail"); + } + + #[tokio::test] + async fn tunnel_gets_actual_port_not_zero() { + // Simulate what run_gateway does: bind port 0, extract actual port + let port: u16 = 0; + let host = "127.0.0.1"; + let listener = TokioListener::bind(format!("{host}:{port}")).await.unwrap(); + let actual_port = listener.local_addr().unwrap().port(); + + // This is the port that would be passed to tun.start(host, actual_port) + assert_ne!(actual_port, 0, "Tunnel must receive actual port, not 0"); + assert!( + actual_port >= 1024, + "Tunnel port {actual_port} must be unprivileged" + ); + } + + // ── extract_header tests ───────────────────────────────── #[test] fn extract_header_finds_value() { diff --git a/src/main.rs b/src/main.rs index 2a02f27..1c66172 100644 --- a/src/main.rs +++ b/src/main.rs @@ -69,7 +69,7 @@ enum Commands { /// Start the gateway server (webhooks, websockets) Gateway { - /// Port to listen on + /// Port to listen on (use 0 for random available port) #[arg(short, long, default_value = "8080")] port: u16, @@ -234,9 +234,11 @@ async fn main() -> Result<()> { } => agent::run(config, message, provider, model, temperature).await, Commands::Gateway { port, host } => { - info!("🚀 Starting ZeroClaw Gateway on {host}:{port}"); - info!("POST http://{host}:{port}/webhook — send JSON messages"); - info!("GET http://{host}:{port}/health — health check"); + if port == 0 { + info!("🚀 Starting ZeroClaw Gateway on {host} (random port)"); + } else { + info!("🚀 Starting ZeroClaw Gateway on {host}:{port}"); + } gateway::run_gateway(&host, port, config).await }