From 37e50029cb30098474f57cfb54b5c0d229ebc425 Mon Sep 17 00:00:00 2001 From: Danielle Jenkins Date: Thu, 6 Mar 2025 23:15:54 -0800 Subject: [PATCH] Combine into one single binary --- Cargo.lock | 121 ++++++++++++++++++++++++++++ Cargo.toml | 8 +- README.md | 28 +++++-- docs/usage.md | 21 +++++ src/bin/axum_docs.rs | 155 ++++-------------------------------- src/bin/axum_docs/mod.rs | 136 +++++++++++++++++++++++++++++++ src/bin/cratedocs.rs | 102 ++++++++++++++++++++++++ src/lib.rs | 1 + src/server/axum_docs/mod.rs | 136 +++++++++++++++++++++++++++++++ src/server/mod.rs | 1 + 10 files changed, 561 insertions(+), 148 deletions(-) create mode 100644 src/bin/axum_docs/mod.rs create mode 100644 src/bin/cratedocs.rs create mode 100644 src/server/axum_docs/mod.rs create mode 100644 src/server/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 1d6c1f4..4e3aa28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +dependencies = [ + "anstyle", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "anyhow" version = "1.0.97" @@ -211,6 +261,52 @@ dependencies = [ "windows-link", ] +[[package]] +name = "clap" +version = "4.5.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "convert_case" version = "0.6.0" @@ -242,6 +338,7 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "clap", "futures", "mcp-core", "mcp-macros", @@ -502,6 +599,12 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "http" version = "0.2.12" @@ -820,6 +923,12 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itoa" version = "1.0.15" @@ -1530,6 +1639,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "syn" version = "2.0.99" @@ -1878,6 +1993,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "valuable" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index d094bec..aa984e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,14 @@ tracing-appender = "0.2" anyhow = "1.0" futures = "0.3" rand = "0.8" +clap = { version = "4.4", features = ["derive"] } -# For examples +# Main binary with subcommands +[[bin]] +name = "cratedocs" +path = "src/bin/cratedocs.rs" + +# Keep existing binaries for backward compatibility [[bin]] name = "doc-server" path = "src/bin/doc_server.rs" diff --git a/README.md b/README.md index bf75435..917939c 100644 --- a/README.md +++ b/README.md @@ -18,25 +18,39 @@ cargo build --release ## Running the Server -There are two ways to run the documentation server: +There are multiple ways to run the documentation server: -### STDIN/STDOUT Mode +### Using the Unified CLI -This mode is useful for integrating with LLM clients that communicate via standard input/output: +The unified command-line interface provides subcommands for all server modes: ```bash -cargo run --bin doc-server +# Run in STDIN/STDOUT mode +cargo run --bin cratedocs stdio + +# Run in HTTP/SSE mode (default address: 127.0.0.1:8080) +cargo run --bin cratedocs http + +# Run in HTTP/SSE mode with custom address +cargo run --bin cratedocs http --address 0.0.0.0:3000 + +# Enable debug logging +cargo run --bin cratedocs http --debug ``` -### HTTP/SSE Mode +### Legacy Commands -This mode exposes an HTTP endpoint that uses Server-Sent Events (SSE) for communication: +For backward compatibility, you can still use the original binaries: ```bash +# STDIN/STDOUT Mode +cargo run --bin doc-server + +# HTTP/SSE Mode cargo run --bin axum-docs ``` -By default, the server will listen on `http://127.0.0.1:8080/sse`. +By default, the HTTP server will listen on `http://127.0.0.1:8080/sse`. ## Available Tools diff --git a/docs/usage.md b/docs/usage.md index 4b79350..160da56 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -88,6 +88,27 @@ async function callTool(toolName, args) { callTool('search_crates', { query: 'async runtime', limit: 5 }); ``` +## Using the CLI + +The CrateDocs MCP server can be started using the unified CLI: + +```bash +# Show help +cargo run --bin cratedocs -- --help + +# Run in STDIN/STDOUT mode +cargo run --bin cratedocs stdio + +# Run in HTTP/SSE mode with default settings +cargo run --bin cratedocs http + +# Run HTTP server on custom address and port +cargo run --bin cratedocs http --address 0.0.0.0:3000 + +# Enable debug logging +cargo run --bin cratedocs http --debug +``` + ## Example Workflows ### Helping an LLM Understand a New Crate diff --git a/src/bin/axum_docs.rs b/src/bin/axum_docs.rs index 0faac2f..a0f78e6 100644 --- a/src/bin/axum_docs.rs +++ b/src/bin/axum_docs.rs @@ -1,146 +1,13 @@ -use axum::{ - body::Body, - extract::{Query, State}, - http::StatusCode, - response::sse::{Event, Sse}, - routing::get, - Router, -}; -use futures::{stream::Stream, StreamExt, TryStreamExt}; -use mcp_server::{ByteTransport, Server}; -use std::collections::HashMap; -use tokio_util::codec::FramedRead; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - use anyhow::Result; -use mcp_server::router::RouterService; -use cratedocs_mcp::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter}; -use std::sync::Arc; -use tokio::{ - io::{self, AsyncWriteExt}, - sync::Mutex, -}; -use tracing_subscriber::{self}; - -type C2SWriter = Arc>>; -type SessionId = Arc; +use std::net::SocketAddr; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use cratedocs_mcp::server::axum_docs::App; const BIND_ADDRESS: &str = "127.0.0.1:8080"; -#[derive(Clone, Default)] -pub struct App { - txs: Arc>>, -} - -impl App { - pub fn new() -> Self { - Self { - txs: Default::default(), - } - } - pub fn router(&self) -> Router { - Router::new() - .route("/sse", get(sse_handler).post(post_event_handler)) - .with_state(self.clone()) - } -} - -fn session_id() -> SessionId { - let id = format!("{:016x}", rand::random::()); - Arc::from(id) -} - -#[derive(Debug, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PostEventQuery { - pub session_id: String, -} - -async fn post_event_handler( - State(app): State, - Query(PostEventQuery { session_id }): Query, - body: Body, -) -> Result { - const BODY_BYTES_LIMIT: usize = 1 << 22; - let write_stream = { - let rg = app.txs.read().await; - rg.get(session_id.as_str()) - .ok_or(StatusCode::NOT_FOUND)? - .clone() - }; - let mut write_stream = write_stream.lock().await; - let mut body = body.into_data_stream(); - if let (_, Some(size)) = body.size_hint() { - if size > BODY_BYTES_LIMIT { - return Err(StatusCode::PAYLOAD_TOO_LARGE); - } - } - // calculate the body size - let mut size = 0; - while let Some(chunk) = body.next().await { - let Ok(chunk) = chunk else { - return Err(StatusCode::BAD_REQUEST); - }; - size += chunk.len(); - if size > BODY_BYTES_LIMIT { - return Err(StatusCode::PAYLOAD_TOO_LARGE); - } - write_stream - .write_all(&chunk) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - } - write_stream - .write_u8(b'\n') - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - Ok(StatusCode::ACCEPTED) -} - -async fn sse_handler(State(app): State) -> Sse>> { - // it's 4KB - const BUFFER_SIZE: usize = 1 << 12; - let session = session_id(); - tracing::info!(%session, "sse connection"); - let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); - let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); - app.txs - .write() - .await - .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); - { - let app_clone = app.clone(); - let session = session.clone(); - tokio::spawn(async move { - let router = RouterService(DocRouter::new()); - let server = Server::new(router); - let bytes_transport = ByteTransport::new(c2s_read, s2c_write); - let _result = server - .run(bytes_transport) - .await - .inspect_err(|e| tracing::error!(?e, "server run error")); - app_clone.txs.write().await.remove(&session); - }); - } - - let stream = futures::stream::once(futures::future::ok( - Event::default() - .event("endpoint") - .data(format!("?sessionId={session}")), - )) - .chain( - FramedRead::new(s2c_read, JsonRpcFrameCodec) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) - .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) { - Ok(message) => futures::future::ok(Event::default().event("message").data(message)), - Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), - }), - ); - Sse::new(stream) -} - #[tokio::main] -async fn main() -> io::Result<()> { +async fn main() -> Result<()> { + // Setup tracing tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() @@ -148,9 +15,17 @@ async fn main() -> io::Result<()> { ) .with(tracing_subscriber::fmt::layer()) .init(); - let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; + + // Parse socket address + let addr: SocketAddr = BIND_ADDRESS.parse()?; + let listener = tokio::net::TcpListener::bind(addr).await?; tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?); tracing::info!("Access the Rust Documentation Server at http://{}/sse", BIND_ADDRESS); - axum::serve(listener, App::new().router()).await + + // Create app and run server + let app = App::new(); + axum::serve(listener, app.router()).await?; + + Ok(()) } \ No newline at end of file diff --git a/src/bin/axum_docs/mod.rs b/src/bin/axum_docs/mod.rs new file mode 100644 index 0000000..de8635f --- /dev/null +++ b/src/bin/axum_docs/mod.rs @@ -0,0 +1,136 @@ +use axum::{ + body::Body, + extract::{Query, State}, + http::StatusCode, + response::sse::{Event, Sse}, + routing::get, + Router, +}; +use futures::{stream::Stream, StreamExt, TryStreamExt}; +use mcp_server::{ByteTransport, Server}; +use std::collections::HashMap; +use tokio_util::codec::FramedRead; + +use anyhow::Result; +use mcp_server::router::RouterService; +use cratedocs_mcp::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter}; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + pub fn router(&self) -> Router { + Router::new() + .route("/sse", get(sse_handler).post(post_event_handler)) + .with_state(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + body: Body, +) -> Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let mut body = body.into_data_stream(); + if let (_, Some(size)) = body.size_hint() { + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + } + // calculate the body size + let mut size = 0; + while let Some(chunk) = body.next().await { + let Ok(chunk) = chunk else { + return Err(StatusCode::BAD_REQUEST); + }; + size += chunk.len(); + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + write_stream + .write_all(&chunk) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + write_stream + .write_u8(b'\n') + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler(State(app): State) -> Sse>> { + // it's 4KB + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + { + let app_clone = app.clone(); + let session = session.clone(); + tokio::spawn(async move { + let router = RouterService(DocRouter::new()); + let server = Server::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app_clone.txs.write().await.remove(&session); + }); + } + + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("?sessionId={session}")), + )) + .chain( + FramedRead::new(s2c_read, JsonRpcFrameCodec) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) { + Ok(message) => futures::future::ok(Event::default().event("message").data(message)), + Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), + }), + ); + Sse::new(stream) +} \ No newline at end of file diff --git a/src/bin/cratedocs.rs b/src/bin/cratedocs.rs new file mode 100644 index 0000000..a866528 --- /dev/null +++ b/src/bin/cratedocs.rs @@ -0,0 +1,102 @@ +use anyhow::Result; +use clap::{Parser, Subcommand}; +use cratedocs_mcp::DocRouter; +use mcp_server::router::RouterService; +use mcp_server::{ByteTransport, Server}; +use std::net::SocketAddr; +use tokio::io::{stdin, stdout}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{self, EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +#[command(propagate_version = true)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// Run the server in stdin/stdout mode + Stdio { + /// Enable debug logging + #[arg(short, long)] + debug: bool, + }, + /// Run the server with HTTP/SSE interface + Http { + /// Address to bind the HTTP server to + #[arg(short, long, default_value = "127.0.0.1:8080")] + address: String, + + /// Enable debug logging + #[arg(short, long)] + debug: bool, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + + match cli.command { + Commands::Stdio { debug } => run_stdio_server(debug).await, + Commands::Http { address, debug } => run_http_server(address, debug).await, + } +} + +async fn run_stdio_server(debug: bool) -> Result<()> { + // Set up file appender for logging + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "doc-server.log"); + + // Initialize the tracing subscriber with file logging + let level = if debug { tracing::Level::DEBUG } else { tracing::Level::INFO }; + + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(level.into())) + .with_writer(file_appender) + .with_target(false) + .with_thread_ids(true) + .with_file(true) + .with_line_number(true) + .init(); + + tracing::info!("Starting MCP documentation server in STDIN/STDOUT mode"); + + // Create an instance of our documentation router + let router = RouterService(DocRouter::new()); + + // Create and run the server + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + + tracing::info!("Documentation server initialized and ready to handle requests"); + Ok(server.run(transport).await?) +} + +async fn run_http_server(address: String, debug: bool) -> Result<()> { + // Setup tracing + let level = if debug { "debug" } else { "info" }; + + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{},{}", level, env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Parse socket address + let addr: SocketAddr = address.parse()?; + let listener = tokio::net::TcpListener::bind(addr).await?; + + tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?); + tracing::info!("Access the Rust Documentation Server at http://{}/sse", addr); + + // Create app and run server + let app = cratedocs_mcp::server::axum_docs::App::new(); + axum::serve(listener, app.router()).await?; + + Ok(()) +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 63b6ad5..34ea461 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod docs; pub mod jsonrpc_frame_codec; +pub mod server; // Re-export key components for easier access pub use docs::DocRouter; \ No newline at end of file diff --git a/src/server/axum_docs/mod.rs b/src/server/axum_docs/mod.rs new file mode 100644 index 0000000..4401b1f --- /dev/null +++ b/src/server/axum_docs/mod.rs @@ -0,0 +1,136 @@ +use axum::{ + body::Body, + extract::{Query, State}, + http::StatusCode, + response::sse::{Event, Sse}, + routing::get, + Router, +}; +use futures::{stream::Stream, StreamExt, TryStreamExt}; +use mcp_server::{ByteTransport, Server}; +use std::collections::HashMap; +use tokio_util::codec::FramedRead; + +use anyhow::Result; +use mcp_server::router::RouterService; +use crate::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter}; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + pub fn router(&self) -> Router { + Router::new() + .route("/sse", get(sse_handler).post(post_event_handler)) + .with_state(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + body: Body, +) -> Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let mut body = body.into_data_stream(); + if let (_, Some(size)) = body.size_hint() { + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + } + // calculate the body size + let mut size = 0; + while let Some(chunk) = body.next().await { + let Ok(chunk) = chunk else { + return Err(StatusCode::BAD_REQUEST); + }; + size += chunk.len(); + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + write_stream + .write_all(&chunk) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + write_stream + .write_u8(b'\n') + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler(State(app): State) -> Sse>> { + // it's 4KB + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + { + let app_clone = app.clone(); + let session = session.clone(); + tokio::spawn(async move { + let router = RouterService(DocRouter::new()); + let server = Server::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app_clone.txs.write().await.remove(&session); + }); + } + + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("?sessionId={session}")), + )) + .chain( + FramedRead::new(s2c_read, JsonRpcFrameCodec) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) { + Ok(message) => futures::future::ok(Event::default().event("message").data(message)), + Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), + }), + ); + Sse::new(stream) +} \ No newline at end of file diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..1204c0e --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1 @@ +pub mod axum_docs; \ No newline at end of file