Combine into one single binary

parent 06514ed935
commit 37e50029cb

10 changed files with 561 additions and 148 deletions
@@ -1,146 +1,13 @@
use axum::{
    body::Body,
    extract::{Query, State},
    http::StatusCode,
    response::sse::{Event, Sse},
    routing::get,
    Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use anyhow::Result;
use mcp_server::router::RouterService;
use cratedocs_mcp::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter};
use std::sync::Arc;
use tokio::{
    io::{self, AsyncWriteExt},
    sync::Mutex,
};
use tracing_subscriber::{self};

type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;
use std::net::SocketAddr;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use cratedocs_mcp::server::axum_docs::App;

const BIND_ADDRESS: &str = "127.0.0.1:8080";

#[derive(Clone, Default)]
pub struct App {
    txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}

impl App {
    pub fn new() -> Self {
        Self {
            txs: Default::default(),
        }
    }
    pub fn router(&self) -> Router {
        Router::new()
            .route("/sse", get(sse_handler).post(post_event_handler))
            .with_state(self.clone())
    }
}

fn session_id() -> SessionId {
    let id = format!("{:016x}", rand::random::<u128>());
    Arc::from(id)
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
    pub session_id: String,
}

async fn post_event_handler(
    State(app): State<App>,
    Query(PostEventQuery { session_id }): Query<PostEventQuery>,
    body: Body,
) -> Result<StatusCode, StatusCode> {
    const BODY_BYTES_LIMIT: usize = 1 << 22;
    let write_stream = {
        let rg = app.txs.read().await;
        rg.get(session_id.as_str())
            .ok_or(StatusCode::NOT_FOUND)?
            .clone()
    };
    let mut write_stream = write_stream.lock().await;
    let mut body = body.into_data_stream();
    if let (_, Some(size)) = body.size_hint() {
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    }
    // calculate the body size
    let mut size = 0;
    while let Some(chunk) = body.next().await {
        let Ok(chunk) = chunk else {
            return Err(StatusCode::BAD_REQUEST);
        };
        size += chunk.len();
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        write_stream
            .write_all(&chunk)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    }
    write_stream
        .write_u8(b'\n')
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(StatusCode::ACCEPTED)
}

async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
    // it's 4KB
    const BUFFER_SIZE: usize = 1 << 12;
    let session = session_id();
    tracing::info!(%session, "sse connection");
    let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
    let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
    app.txs
        .write()
        .await
        .insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
    {
        let app_clone = app.clone();
        let session = session.clone();
        tokio::spawn(async move {
            let router = RouterService(DocRouter::new());
            let server = Server::new(router);
            let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
            let _result = server
                .run(bytes_transport)
                .await
                .inspect_err(|e| tracing::error!(?e, "server run error"));
            app_clone.txs.write().await.remove(&session);
        });
    }

    let stream = futures::stream::once(futures::future::ok(
        Event::default()
            .event("endpoint")
            .data(format!("?sessionId={session}")),
    ))
    .chain(
        FramedRead::new(s2c_read, JsonRpcFrameCodec)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) {
                Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
                Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
            }),
    );
    Sse::new(stream)
}

#[tokio::main]
async fn main() -> io::Result<()> {
async fn main() -> Result<()> {
    // Setup tracing
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
@@ -148,9 +15,17 @@ async fn main() -> io::Result<()> {
        )
        .with(tracing_subscriber::fmt::layer())
        .init();
    let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;

    // Parse socket address
    let addr: SocketAddr = BIND_ADDRESS.parse()?;
    let listener = tokio::net::TcpListener::bind(addr).await?;

    tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?);
    tracing::info!("Access the Rust Documentation Server at http://{}/sse", BIND_ADDRESS);
    axum::serve(listener, App::new().router()).await

    // Create app and run server
    let app = App::new();
    axum::serve(listener, app.router()).await?;

    Ok(())
}
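
Read together, the two hunks above shrink this binary to a thin wrapper around the library. The following is a sketch of how the resulting file plausibly reads, assembled only from the added lines shown; it is a reconstruction, not the committed file: the EnvFilter fallback sits between the two hunks and is a placeholder here, and keeping `use anyhow::Result;` is an assumption.

use anyhow::Result; // assumed to be kept for the new `Result<()>` return type
use std::net::SocketAddr;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use cratedocs_mcp::server::axum_docs::App;

const BIND_ADDRESS: &str = "127.0.0.1:8080";

#[tokio::main]
async fn main() -> Result<()> {
    // Setup tracing
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                // Placeholder fallback: the real directive falls between the two hunks above.
                .unwrap_or_else(|_| "info".to_string().into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Parse socket address
    let addr: SocketAddr = BIND_ADDRESS.parse()?;
    let listener = tokio::net::TcpListener::bind(addr).await?;

    tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?);
    tracing::info!("Access the Rust Documentation Server at http://{}/sse", BIND_ADDRESS);

    // Create app and run server
    let app = App::new();
    axum::serve(listener, app.router()).await?;

    Ok(())
}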

src/bin/axum_docs/mod.rs (new file, 136 lines)
@@ -0,0 +1,136 @@
use axum::{
    body::Body,
    extract::{Query, State},
    http::StatusCode,
    response::sse::{Event, Sse},
    routing::get,
    Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;

use anyhow::Result;
use mcp_server::router::RouterService;
use cratedocs_mcp::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter};
use std::sync::Arc;
use tokio::{
    io::{self, AsyncWriteExt},
    sync::Mutex,
};

type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;

#[derive(Clone, Default)]
pub struct App {
    txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}

impl App {
    pub fn new() -> Self {
        Self {
            txs: Default::default(),
        }
    }
    pub fn router(&self) -> Router {
        Router::new()
            .route("/sse", get(sse_handler).post(post_event_handler))
            .with_state(self.clone())
    }
}

fn session_id() -> SessionId {
    let id = format!("{:016x}", rand::random::<u128>());
    Arc::from(id)
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
    pub session_id: String,
}

async fn post_event_handler(
    State(app): State<App>,
    Query(PostEventQuery { session_id }): Query<PostEventQuery>,
    body: Body,
) -> Result<StatusCode, StatusCode> {
    const BODY_BYTES_LIMIT: usize = 1 << 22;
    let write_stream = {
        let rg = app.txs.read().await;
        rg.get(session_id.as_str())
            .ok_or(StatusCode::NOT_FOUND)?
            .clone()
    };
    let mut write_stream = write_stream.lock().await;
    let mut body = body.into_data_stream();
    if let (_, Some(size)) = body.size_hint() {
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    }
    // calculate the body size
    let mut size = 0;
    while let Some(chunk) = body.next().await {
        let Ok(chunk) = chunk else {
            return Err(StatusCode::BAD_REQUEST);
        };
        size += chunk.len();
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        write_stream
            .write_all(&chunk)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    }
    write_stream
        .write_u8(b'\n')
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(StatusCode::ACCEPTED)
}

async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
    // it's 4KB
    const BUFFER_SIZE: usize = 1 << 12;
    let session = session_id();
    tracing::info!(%session, "sse connection");
    let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
    let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
    app.txs
        .write()
        .await
        .insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
    {
        let app_clone = app.clone();
        let session = session.clone();
        tokio::spawn(async move {
            let router = RouterService(DocRouter::new());
            let server = Server::new(router);
            let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
            let _result = server
                .run(bytes_transport)
                .await
                .inspect_err(|e| tracing::error!(?e, "server run error"));
            app_clone.txs.write().await.remove(&session);
        });
    }

    let stream = futures::stream::once(futures::future::ok(
        Event::default()
            .event("endpoint")
            .data(format!("?sessionId={session}")),
    ))
    .chain(
        FramedRead::new(s2c_read, JsonRpcFrameCodec)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) {
                Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
                Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
            }),
    );
    Sse::new(stream)
}

src/bin/cratedocs.rs (new file, 102 lines)
@@ -0,0 +1,102 @@
use anyhow::Result;
use clap::{Parser, Subcommand};
use cratedocs_mcp::DocRouter;
use mcp_server::router::RouterService;
use mcp_server::{ByteTransport, Server};
use std::net::SocketAddr;
use tokio::io::{stdin, stdout};
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{self, EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(propagate_version = true)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    /// Run the server in stdin/stdout mode
    Stdio {
        /// Enable debug logging
        #[arg(short, long)]
        debug: bool,
    },
    /// Run the server with HTTP/SSE interface
    Http {
        /// Address to bind the HTTP server to
        #[arg(short, long, default_value = "127.0.0.1:8080")]
        address: String,

        /// Enable debug logging
        #[arg(short, long)]
        debug: bool,
    },
}

#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();

    match cli.command {
        Commands::Stdio { debug } => run_stdio_server(debug).await,
        Commands::Http { address, debug } => run_http_server(address, debug).await,
    }
}

async fn run_stdio_server(debug: bool) -> Result<()> {
    // Set up file appender for logging
    let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "doc-server.log");

    // Initialize the tracing subscriber with file logging
    let level = if debug { tracing::Level::DEBUG } else { tracing::Level::INFO };

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(level.into()))
        .with_writer(file_appender)
        .with_target(false)
        .with_thread_ids(true)
        .with_file(true)
        .with_line_number(true)
        .init();

    tracing::info!("Starting MCP documentation server in STDIN/STDOUT mode");

    // Create an instance of our documentation router
    let router = RouterService(DocRouter::new());

    // Create and run the server
    let server = Server::new(router);
    let transport = ByteTransport::new(stdin(), stdout());

    tracing::info!("Documentation server initialized and ready to handle requests");
    Ok(server.run(transport).await?)
}

async fn run_http_server(address: String, debug: bool) -> Result<()> {
    // Setup tracing
    let level = if debug { "debug" } else { "info" };

    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| format!("{},{}", level, env!("CARGO_CRATE_NAME")).into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    // Parse socket address
    let addr: SocketAddr = address.parse()?;
    let listener = tokio::net::TcpListener::bind(addr).await?;

    tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?);
    tracing::info!("Access the Rust Documentation Server at http://{}/sse", addr);

    // Create app and run server
    let app = cratedocs_mcp::server::axum_docs::App::new();
    axum::serve(listener, app.router()).await?;

    Ok(())
}
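
A minimal, self-contained sketch of the command-line surface this file defines, exercised with clap's parse_from; Cli and Commands are re-declared here purely for illustration because the originals are private to the binary, and the argument values are made up.

use clap::{Parser, Subcommand};

// Mirror of the definitions in src/bin/cratedocs.rs, trimmed for illustration.
#[derive(Parser)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    /// `cratedocs stdio [--debug]`
    Stdio {
        #[arg(short, long)]
        debug: bool,
    },
    /// `cratedocs http [--address <ADDR>] [--debug]`
    Http {
        #[arg(short, long, default_value = "127.0.0.1:8080")]
        address: String,
        #[arg(short, long)]
        debug: bool,
    },
}

fn main() {
    // Simulate `cratedocs http --address 0.0.0.0:3000 --debug`.
    let cli = Cli::parse_from(["cratedocs", "http", "--address", "0.0.0.0:3000", "--debug"]);
    match cli.command {
        Commands::Stdio { debug } => println!("stdio transport, debug={debug}"),
        Commands::Http { address, debug } => println!("http transport on {address}, debug={debug}"),
    }
}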

@@ -1,5 +1,6 @@
pub mod docs;
pub mod jsonrpc_frame_codec;
pub mod server;

// Re-export key components for easier access
pub use docs::DocRouter;
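
A two-path sketch of what the new re-export allows, assuming cratedocs_mcp is pulled in as a dependency; both imports below name the same type.

use cratedocs_mcp::docs::DocRouter as ViaModulePath;
use cratedocs_mcp::DocRouter as ViaReexport;

fn main() {
    // Both aliases resolve to docs::DocRouter; the re-export only shortens the path.
    let _long = ViaModulePath::new();
    let _short = ViaReexport::new();
}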

src/server/axum_docs/mod.rs (new file, 136 lines)
@@ -0,0 +1,136 @@
use axum::{
    body::Body,
    extract::{Query, State},
    http::StatusCode,
    response::sse::{Event, Sse},
    routing::get,
    Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;

use anyhow::Result;
use mcp_server::router::RouterService;
use crate::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter};
use std::sync::Arc;
use tokio::{
    io::{self, AsyncWriteExt},
    sync::Mutex,
};

type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;

#[derive(Clone, Default)]
pub struct App {
    txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}

impl App {
    pub fn new() -> Self {
        Self {
            txs: Default::default(),
        }
    }
    pub fn router(&self) -> Router {
        Router::new()
            .route("/sse", get(sse_handler).post(post_event_handler))
            .with_state(self.clone())
    }
}

fn session_id() -> SessionId {
    let id = format!("{:016x}", rand::random::<u128>());
    Arc::from(id)
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
    pub session_id: String,
}

async fn post_event_handler(
    State(app): State<App>,
    Query(PostEventQuery { session_id }): Query<PostEventQuery>,
    body: Body,
) -> Result<StatusCode, StatusCode> {
    const BODY_BYTES_LIMIT: usize = 1 << 22;
    let write_stream = {
        let rg = app.txs.read().await;
        rg.get(session_id.as_str())
            .ok_or(StatusCode::NOT_FOUND)?
            .clone()
    };
    let mut write_stream = write_stream.lock().await;
    let mut body = body.into_data_stream();
    if let (_, Some(size)) = body.size_hint() {
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
    }
    // calculate the body size
    let mut size = 0;
    while let Some(chunk) = body.next().await {
        let Ok(chunk) = chunk else {
            return Err(StatusCode::BAD_REQUEST);
        };
        size += chunk.len();
        if size > BODY_BYTES_LIMIT {
            return Err(StatusCode::PAYLOAD_TOO_LARGE);
        }
        write_stream
            .write_all(&chunk)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    }
    write_stream
        .write_u8(b'\n')
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(StatusCode::ACCEPTED)
}

async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
    // it's 4KB
    const BUFFER_SIZE: usize = 1 << 12;
    let session = session_id();
    tracing::info!(%session, "sse connection");
    let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
    let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
    app.txs
        .write()
        .await
        .insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
    {
        let app_clone = app.clone();
        let session = session.clone();
        tokio::spawn(async move {
            let router = RouterService(DocRouter::new());
            let server = Server::new(router);
            let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
            let _result = server
                .run(bytes_transport)
                .await
                .inspect_err(|e| tracing::error!(?e, "server run error"));
            app_clone.txs.write().await.remove(&session);
        });
    }

    let stream = futures::stream::once(futures::future::ok(
        Event::default()
            .event("endpoint")
            .data(format!("?sessionId={session}")),
    ))
    .chain(
        FramedRead::new(s2c_read, JsonRpcFrameCodec)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) {
                Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
                Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
            }),
    );
    Sse::new(stream)
}
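
For context, a hedged client-side sketch of the session flow implemented above: GET /sse, take the ?sessionId=... value from the initial endpoint event, then POST newline-delimited JSON-RPC messages to that URL and expect 202 Accepted. It assumes a server from this commit running on 127.0.0.1:8080 and the reqwest crate; neither is part of this diff, and the session id and request body below are placeholders.

use anyhow::Result;

#[tokio::main]
async fn main() -> Result<()> {
    let base = "http://127.0.0.1:8080";

    // 1. Open the SSE stream; sse_handler first emits an `endpoint` event whose
    //    data looks like "?sessionId=<hex id>". A real client would keep this
    //    response open and parse SSE frames; here we only show the request.
    let sse = reqwest::get(format!("{base}/sse")).await?;
    println!("SSE stream opened: {}", sse.status());

    // 2. Post a JSON-RPC message to the endpoint from step 1. The id below is a
    //    placeholder for the value taken from the `endpoint` event, and the body
    //    stands in for a real MCP request.
    let session_id = "0123456789abcdef0123456789abcdef";
    let body = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
    let resp = reqwest::Client::new()
        .post(format!("{base}/sse?sessionId={session_id}"))
        .body(body.to_string())
        .send()
        .await?;
    // post_event_handler appends a newline and replies 202 Accepted.
    println!("POST status: {}", resp.status());
    Ok(())
}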

src/server/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod axum_docs;