This commit is contained in:
Danielle Jenkins 2025-03-06 22:49:18 -08:00
commit dcf78edfca
12 changed files with 1158 additions and 0 deletions

156
src/bin/axum_docs.rs Normal file
View file

@ -0,0 +1,156 @@
use axum::{
body::Body,
extract::{Query, State},
http::StatusCode,
response::sse::{Event, Sse},
routing::get,
Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use anyhow::Result;
use mcp_server::router::RouterService;
use rust_doc_server::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter};
use std::sync::Arc;
use tokio::{
io::{self, AsyncWriteExt},
sync::Mutex,
};
use tracing_subscriber::{self};
type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;
const BIND_ADDRESS: &str = "127.0.0.1:8080";
#[derive(Clone, Default)]
pub struct App {
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}
impl App {
pub fn new() -> Self {
Self {
txs: Default::default(),
}
}
pub fn router(&self) -> Router {
Router::new()
.route("/sse", get(sse_handler).post(post_event_handler))
.with_state(self.clone())
}
}
fn session_id() -> SessionId {
let id = format!("{:016x}", rand::random::<u128>());
Arc::from(id)
}
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
pub session_id: String,
}
async fn post_event_handler(
State(app): State<App>,
Query(PostEventQuery { session_id }): Query<PostEventQuery>,
body: Body,
) -> Result<StatusCode, StatusCode> {
const BODY_BYTES_LIMIT: usize = 1 << 22;
let write_stream = {
let rg = app.txs.read().await;
rg.get(session_id.as_str())
.ok_or(StatusCode::NOT_FOUND)?
.clone()
};
let mut write_stream = write_stream.lock().await;
let mut body = body.into_data_stream();
if let (_, Some(size)) = body.size_hint() {
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
}
// calculate the body size
let mut size = 0;
while let Some(chunk) = body.next().await {
let Ok(chunk) = chunk else {
return Err(StatusCode::BAD_REQUEST);
};
size += chunk.len();
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
write_stream
.write_all(&chunk)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
write_stream
.write_u8(b'\n')
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::ACCEPTED)
}
async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
// it's 4KB
const BUFFER_SIZE: usize = 1 << 12;
let session = session_id();
tracing::info!(%session, "sse connection");
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
app.txs
.write()
.await
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
{
let app_clone = app.clone();
let session = session.clone();
tokio::spawn(async move {
let router = RouterService(DocRouter::new());
let server = Server::new(router);
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
let _result = server
.run(bytes_transport)
.await
.inspect_err(|e| tracing::error!(?e, "server run error"));
app_clone.txs.write().await.remove(&session);
});
}
let stream = futures::stream::once(futures::future::ok(
Event::default()
.event("endpoint")
.data(format!("?sessionId={session}")),
))
.chain(
FramedRead::new(s2c_read, JsonRpcFrameCodec)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
.and_then(move |bytes| match std::str::from_utf8(&bytes) {
Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
}),
);
Sse::new(stream)
}
#[tokio::main]
async fn main() -> io::Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;
tracing::debug!("Rust Documentation Server listening on {}", listener.local_addr()?);
tracing::info!("Access the Rust Documentation Server at http://{}/sse", BIND_ADDRESS);
axum::serve(listener, App::new().router()).await
}

35
src/bin/doc_server.rs Normal file
View file

@ -0,0 +1,35 @@
use anyhow::Result;
use mcp_server::router::RouterService;
use mcp_server::{ByteTransport, Server};
use rust_doc_server::DocRouter;
use tokio::io::{stdin, stdout};
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_subscriber::{self, EnvFilter};
#[tokio::main]
async fn main() -> Result<()> {
// Set up file appender for logging
let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "doc-server.log");
// Initialize the tracing subscriber with file and stdout logging
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into()))
.with_writer(file_appender)
.with_target(false)
.with_thread_ids(true)
.with_file(true)
.with_line_number(true)
.init();
tracing::info!("Starting MCP documentation server");
// Create an instance of our documentation router
let router = RouterService(DocRouter::new());
// Create and run the server
let server = Server::new(router);
let transport = ByteTransport::new(stdin(), stdout());
tracing::info!("Documentation server initialized and ready to handle requests");
Ok(server.run(transport).await?)
}

340
src/docs.rs Normal file
View file

@ -0,0 +1,340 @@
use std::{future::Future, pin::Pin, sync::Arc};
use mcp_core::{
handler::{PromptError, ResourceError},
prompt::Prompt,
protocol::ServerCapabilities,
Content, Resource, Tool, ToolError,
};
use mcp_server::router::CapabilitiesBuilder;
use reqwest::Client;
use serde_json::{json, Value};
use tokio::sync::Mutex;
// Cache for documentation lookups to avoid repeated requests
#[derive(Clone)]
struct DocCache {
cache: Arc<Mutex<std::collections::HashMap<String, String>>>,
}
impl DocCache {
fn new() -> Self {
Self {
cache: Arc::new(Mutex::new(std::collections::HashMap::new())),
}
}
async fn get(&self, key: &str) -> Option<String> {
let cache = self.cache.lock().await;
cache.get(key).cloned()
}
async fn set(&self, key: String, value: String) {
let mut cache = self.cache.lock().await;
cache.insert(key, value);
}
}
#[derive(Clone)]
pub struct DocRouter {
client: Client,
cache: DocCache,
}
impl DocRouter {
pub fn new() -> Self {
Self {
client: Client::new(),
cache: DocCache::new(),
}
}
// Fetch crate documentation from docs.rs
async fn lookup_crate(&self, crate_name: String, version: Option<String>) -> Result<String, ToolError> {
// Check cache first
let cache_key = if let Some(ver) = &version {
format!("{}:{}", crate_name, ver)
} else {
crate_name.clone()
};
if let Some(doc) = self.cache.get(&cache_key).await {
return Ok(doc);
}
// Construct the docs.rs URL for the crate
let url = if let Some(ver) = version {
format!("https://docs.rs/crate/{}/{}/", crate_name, ver)
} else {
format!("https://docs.rs/crate/{}/", crate_name)
};
// Fetch the documentation page
let response = self.client.get(&url).send().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to fetch documentation: {}", e))
})?;
if !response.status().is_success() {
return Err(ToolError::ExecutionError(format!(
"Failed to fetch documentation. Status: {}",
response.status()
)));
}
let body = response.text().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to read response body: {}", e))
})?;
// Cache the result
self.cache.set(cache_key, body.clone()).await;
Ok(body)
}
// Search crates.io for crates matching a query
async fn search_crates(&self, query: String, limit: Option<u32>) -> Result<String, ToolError> {
let limit = limit.unwrap_or(10).min(100); // Cap at 100 results
let url = format!("https://crates.io/api/v1/crates?q={}&per_page={}", query, limit);
let response = self.client.get(&url).send().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to search crates.io: {}", e))
})?;
if !response.status().is_success() {
return Err(ToolError::ExecutionError(format!(
"Failed to search crates.io. Status: {}",
response.status()
)));
}
let body = response.text().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to read response body: {}", e))
})?;
Ok(body)
}
// Get documentation for a specific item in a crate
async fn lookup_item(&self, crate_name: String, item_path: String, version: Option<String>) -> Result<String, ToolError> {
// Check cache first
let cache_key = if let Some(ver) = &version {
format!("{}:{}:{}", crate_name, ver, item_path)
} else {
format!("{}:{}", crate_name, item_path)
};
if let Some(doc) = self.cache.get(&cache_key).await {
return Ok(doc);
}
// Construct the docs.rs URL for the specific item
let url = if let Some(ver) = version {
format!("https://docs.rs/{}/{}/{}/", crate_name, ver, item_path.replace("::", "/"))
} else {
format!("https://docs.rs/{}/latest/{}/", crate_name, item_path.replace("::", "/"))
};
// Fetch the documentation page
let response = self.client.get(&url).send().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to fetch item documentation: {}", e))
})?;
if !response.status().is_success() {
return Err(ToolError::ExecutionError(format!(
"Failed to fetch item documentation. Status: {}",
response.status()
)));
}
let body = response.text().await.map_err(|e| {
ToolError::ExecutionError(format!("Failed to read response body: {}", e))
})?;
// Cache the result
self.cache.set(cache_key, body.clone()).await;
Ok(body)
}
}
impl mcp_server::Router for DocRouter {
fn name(&self) -> String {
"rust-docs".to_string()
}
fn instructions(&self) -> String {
"This server provides tools for looking up Rust crate documentation. \
You can search for crates, lookup documentation for specific crates or \
items within crates. Use these tools to find information about Rust libraries \
you are not familiar with.".to_string()
}
fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new()
.with_tools(true)
.with_resources(false, false)
.with_prompts(false)
.build()
}
fn list_tools(&self) -> Vec<Tool> {
vec![
Tool::new(
"lookup_crate".to_string(),
"Look up documentation for a Rust crate".to_string(),
json!({
"type": "object",
"properties": {
"crate_name": {
"type": "string",
"description": "The name of the crate to look up"
},
"version": {
"type": "string",
"description": "The version of the crate (optional, defaults to latest)"
}
},
"required": ["crate_name"]
}),
),
Tool::new(
"search_crates".to_string(),
"Search for Rust crates on crates.io".to_string(),
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return (optional, defaults to 10, max 100)"
}
},
"required": ["query"]
}),
),
Tool::new(
"lookup_item".to_string(),
"Look up documentation for a specific item in a Rust crate".to_string(),
json!({
"type": "object",
"properties": {
"crate_name": {
"type": "string",
"description": "The name of the crate"
},
"item_path": {
"type": "string",
"description": "Path to the item (e.g., 'std::vec::Vec')"
},
"version": {
"type": "string",
"description": "The version of the crate (optional, defaults to latest)"
}
},
"required": ["crate_name", "item_path"]
}),
),
]
}
fn call_tool(
&self,
tool_name: &str,
arguments: Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();
let arguments = arguments.clone();
Box::pin(async move {
match tool_name.as_str() {
"lookup_crate" => {
let crate_name = arguments
.get("crate_name")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("crate_name is required".to_string()))?
.to_string();
let version = arguments
.get("version")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let doc = this.lookup_crate(crate_name, version).await?;
Ok(vec![Content::text(doc)])
}
"search_crates" => {
let query = arguments
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("query is required".to_string()))?
.to_string();
let limit = arguments
.get("limit")
.and_then(|v| v.as_u64())
.map(|v| v as u32);
let results = this.search_crates(query, limit).await?;
Ok(vec![Content::text(results)])
}
"lookup_item" => {
let crate_name = arguments
.get("crate_name")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("crate_name is required".to_string()))?
.to_string();
let item_path = arguments
.get("item_path")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("item_path is required".to_string()))?
.to_string();
let version = arguments
.get("version")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let doc = this.lookup_item(crate_name, item_path, version).await?;
Ok(vec![Content::text(doc)])
}
_ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))),
}
})
}
fn list_resources(&self) -> Vec<Resource> {
vec![]
}
fn read_resource(
&self,
_uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
Box::pin(async move {
Err(ResourceError::NotFound("Resource not found".to_string()))
})
}
fn list_prompts(&self) -> Vec<Prompt> {
vec![]
}
fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
)))
})
}
}

View file

@ -0,0 +1,24 @@
use tokio_util::codec::Decoder;
#[derive(Default)]
pub struct JsonRpcFrameCodec;
impl Decoder for JsonRpcFrameCodec {
type Item = tokio_util::bytes::Bytes;
type Error = tokio::io::Error;
fn decode(
&mut self,
src: &mut tokio_util::bytes::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
if let Some(end) = src
.iter()
.enumerate()
.find_map(|(idx, &b)| (b == b'\n').then_some(idx))
{
let line = src.split_to(end);
let _char_next_line = src.split_to(1);
Ok(Some(line.freeze()))
} else {
Ok(None)
}
}
}

5
src/lib.rs Normal file
View file

@ -0,0 +1,5 @@
pub mod docs;
pub mod jsonrpc_frame_codec;
// Re-export key components for easier access
pub use docs::DocRouter;