From 5fe28f7fe73c43f6e78e3207479aad9a99cc20dc Mon Sep 17 00:00:00 2001 From: Danielle Jenkins Date: Wed, 12 Mar 2025 15:12:37 -0700 Subject: [PATCH] Refactor --- Cargo.toml | 2 +- src/bin/axum_docs/mod.rs | 136 --------- src/bin/{axum_docs.rs => sse_server.rs} | 0 src/docs.rs | 349 ------------------------ src/docs/tests.rs | 152 ----------- src/jsonrpc_frame_codec.rs | 27 -- src/jsonrpc_frame_codec/tests.rs | 72 ----- 7 files changed, 1 insertion(+), 737 deletions(-) delete mode 100644 src/bin/axum_docs/mod.rs rename src/bin/{axum_docs.rs => sse_server.rs} (100%) delete mode 100644 src/docs.rs delete mode 100644 src/docs/tests.rs delete mode 100644 src/jsonrpc_frame_codec.rs delete mode 100644 src/jsonrpc_frame_codec/tests.rs diff --git a/Cargo.toml b/Cargo.toml index 9944d5e..ef7d38c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,4 +53,4 @@ path = "src/bin/stdio_server.rs" [[bin]] name = "http-sse-server" -path = "src/bin/axum_docs.rs" +path = "src/bin/sse_server.rs" diff --git a/src/bin/axum_docs/mod.rs b/src/bin/axum_docs/mod.rs deleted file mode 100644 index 73d637b..0000000 --- a/src/bin/axum_docs/mod.rs +++ /dev/null @@ -1,136 +0,0 @@ -use axum::{ - body::Body, - extract::{Query, State}, - http::StatusCode, - response::sse::{Event, Sse}, - routing::get, - Router, -}; -use futures::{stream::Stream, StreamExt, TryStreamExt}; -use mcp_server::{ByteTransport, Server}; -use std::collections::HashMap; -use tokio_util::codec::FramedRead; - -use anyhow::Result; -use mcp_server::router::RouterService; -use cratedocs_mcp::{transport::jsonrpc_frame_codec::JsonRpcFrameCodec, tools::DocRouter}; -use std::sync::Arc; -use tokio::{ - io::{self, AsyncWriteExt}, - sync::Mutex, -}; - -type C2SWriter = Arc>>; -type SessionId = Arc; - -#[derive(Clone, Default)] -pub struct App { - txs: Arc>>, -} - -impl App { - pub fn new() -> Self { - Self { - txs: Default::default(), - } - } - pub fn router(&self) -> Router { - Router::new() - .route("/sse", get(sse_handler).post(post_event_handler)) - .with_state(self.clone()) - } -} - -fn session_id() -> SessionId { - let id = format!("{:016x}", rand::random::()); - Arc::from(id) -} - -#[derive(Debug, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct PostEventQuery { - pub session_id: String, -} - -async fn post_event_handler( - State(app): State, - Query(PostEventQuery { session_id }): Query, - body: Body, -) -> Result { - const BODY_BYTES_LIMIT: usize = 1 << 22; - let write_stream = { - let rg = app.txs.read().await; - rg.get(session_id.as_str()) - .ok_or(StatusCode::NOT_FOUND)? - .clone() - }; - let mut write_stream = write_stream.lock().await; - let mut body = body.into_data_stream(); - if let (_, Some(size)) = body.size_hint() { - if size > BODY_BYTES_LIMIT { - return Err(StatusCode::PAYLOAD_TOO_LARGE); - } - } - // calculate the body size - let mut size = 0; - while let Some(chunk) = body.next().await { - let Ok(chunk) = chunk else { - return Err(StatusCode::BAD_REQUEST); - }; - size += chunk.len(); - if size > BODY_BYTES_LIMIT { - return Err(StatusCode::PAYLOAD_TOO_LARGE); - } - write_stream - .write_all(&chunk) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - } - write_stream - .write_u8(b'\n') - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - Ok(StatusCode::ACCEPTED) -} - -async fn sse_handler(State(app): State) -> Sse>> { - // it's 4KB - const BUFFER_SIZE: usize = 1 << 12; - let session = session_id(); - tracing::info!(%session, "sse connection"); - let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); - let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); - app.txs - .write() - .await - .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); - { - let app_clone = app.clone(); - let session = session.clone(); - tokio::spawn(async move { - let router = RouterService(DocRouter::new()); - let server = Server::new(router); - let bytes_transport = ByteTransport::new(c2s_read, s2c_write); - let _result = server - .run(bytes_transport) - .await - .inspect_err(|e| tracing::error!(?e, "server run error")); - app_clone.txs.write().await.remove(&session); - }); - } - - let stream = futures::stream::once(futures::future::ok( - Event::default() - .event("endpoint") - .data(format!("?sessionId={session}")), - )) - .chain( - FramedRead::new(s2c_read, JsonRpcFrameCodec) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) - .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) { - Ok(message) => futures::future::ok(Event::default().event("message").data(message)), - Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), - }), - ); - Sse::new(stream) -} \ No newline at end of file diff --git a/src/bin/axum_docs.rs b/src/bin/sse_server.rs similarity index 100% rename from src/bin/axum_docs.rs rename to src/bin/sse_server.rs diff --git a/src/docs.rs b/src/docs.rs deleted file mode 100644 index 4c183bd..0000000 --- a/src/docs.rs +++ /dev/null @@ -1,349 +0,0 @@ -use std::{future::Future, pin::Pin, sync::Arc}; - -use mcp_core::{ - handler::{PromptError, ResourceError}, - prompt::Prompt, - protocol::ServerCapabilities, - Content, Resource, Tool, ToolError, -}; -use mcp_server::router::CapabilitiesBuilder; -use reqwest::Client; -use serde_json::{json, Value}; -use tokio::sync::Mutex; - -#[cfg(test)] -mod tests; - -// Cache for documentation lookups to avoid repeated requests -#[derive(Clone)] -struct DocCache { - cache: Arc>>, -} - -impl DocCache { - fn new() -> Self { - Self { - cache: Arc::new(Mutex::new(std::collections::HashMap::new())), - } - } - - async fn get(&self, key: &str) -> Option { - let cache = self.cache.lock().await; - cache.get(key).cloned() - } - - async fn set(&self, key: String, value: String) { - let mut cache = self.cache.lock().await; - cache.insert(key, value); - } -} - -#[derive(Clone)] -pub struct DocRouter { - client: Client, - cache: DocCache, -} - -impl Default for DocRouter { - fn default() -> Self { - Self::new() - } -} - -impl DocRouter { - pub fn new() -> Self { - Self { - client: Client::new(), - cache: DocCache::new(), - } - } - - // Fetch crate documentation from docs.rs - async fn lookup_crate(&self, crate_name: String, version: Option) -> Result { - // Check cache first - let cache_key = if let Some(ver) = &version { - format!("{}:{}", crate_name, ver) - } else { - crate_name.clone() - }; - - if let Some(doc) = self.cache.get(&cache_key).await { - return Ok(doc); - } - - // Construct the docs.rs URL for the crate - let url = if let Some(ver) = version { - format!("https://docs.rs/crate/{}/{}/", crate_name, ver) - } else { - format!("https://docs.rs/crate/{}/", crate_name) - }; - - // Fetch the documentation page - let response = self.client.get(&url).send().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to fetch documentation: {}", e)) - })?; - - if !response.status().is_success() { - return Err(ToolError::ExecutionError(format!( - "Failed to fetch documentation. Status: {}", - response.status() - ))); - } - - let body = response.text().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to read response body: {}", e)) - })?; - - // Cache the result - self.cache.set(cache_key, body.clone()).await; - - Ok(body) - } - - // Search crates.io for crates matching a query - async fn search_crates(&self, query: String, limit: Option) -> Result { - let limit = limit.unwrap_or(10).min(100); // Cap at 100 results - - let url = format!("https://crates.io/api/v1/crates?q={}&per_page={}", query, limit); - - let response = self.client.get(&url).send().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to search crates.io: {}", e)) - })?; - - if !response.status().is_success() { - return Err(ToolError::ExecutionError(format!( - "Failed to search crates.io. Status: {}", - response.status() - ))); - } - - let body = response.text().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to read response body: {}", e)) - })?; - - Ok(body) - } - - // Get documentation for a specific item in a crate - async fn lookup_item(&self, crate_name: String, item_path: String, version: Option) -> Result { - // Check cache first - let cache_key = if let Some(ver) = &version { - format!("{}:{}:{}", crate_name, ver, item_path) - } else { - format!("{}:{}", crate_name, item_path) - }; - - if let Some(doc) = self.cache.get(&cache_key).await { - return Ok(doc); - } - - // Construct the docs.rs URL for the specific item - let url = if let Some(ver) = version { - format!("https://docs.rs/{}/{}/{}/", crate_name, ver, item_path.replace("::", "/")) - } else { - format!("https://docs.rs/{}/latest/{}/", crate_name, item_path.replace("::", "/")) - }; - - // Fetch the documentation page - let response = self.client.get(&url).send().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to fetch item documentation: {}", e)) - })?; - - if !response.status().is_success() { - return Err(ToolError::ExecutionError(format!( - "Failed to fetch item documentation. Status: {}", - response.status() - ))); - } - - let body = response.text().await.map_err(|e| { - ToolError::ExecutionError(format!("Failed to read response body: {}", e)) - })?; - - // Cache the result - self.cache.set(cache_key, body.clone()).await; - - Ok(body) - } -} - -impl mcp_server::Router for DocRouter { - fn name(&self) -> String { - "rust-docs".to_string() - } - - fn instructions(&self) -> String { - "This server provides tools for looking up Rust crate documentation. \ - You can search for crates, lookup documentation for specific crates or \ - items within crates. Use these tools to find information about Rust libraries \ - you are not familiar with.".to_string() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new() - .with_tools(true) - .with_resources(false, false) - .with_prompts(false) - .build() - } - - fn list_tools(&self) -> Vec { - vec![ - Tool::new( - "lookup_crate".to_string(), - "Look up documentation for a Rust crate".to_string(), - json!({ - "type": "object", - "properties": { - "crate_name": { - "type": "string", - "description": "The name of the crate to look up" - }, - "version": { - "type": "string", - "description": "The version of the crate (optional, defaults to latest)" - } - }, - "required": ["crate_name"] - }), - ), - Tool::new( - "search_crates".to_string(), - "Search for Rust crates on crates.io".to_string(), - json!({ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query" - }, - "limit": { - "type": "integer", - "description": "Maximum number of results to return (optional, defaults to 10, max 100)" - } - }, - "required": ["query"] - }), - ), - Tool::new( - "lookup_item".to_string(), - "Look up documentation for a specific item in a Rust crate".to_string(), - json!({ - "type": "object", - "properties": { - "crate_name": { - "type": "string", - "description": "The name of the crate" - }, - "item_path": { - "type": "string", - "description": "Path to the item (e.g., 'std::vec::Vec')" - }, - "version": { - "type": "string", - "description": "The version of the crate (optional, defaults to latest)" - } - }, - "required": ["crate_name", "item_path"] - }), - ), - ] - } - - fn call_tool( - &self, - tool_name: &str, - arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - let arguments = arguments.clone(); - - Box::pin(async move { - match tool_name.as_str() { - "lookup_crate" => { - let crate_name = arguments - .get("crate_name") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::InvalidParameters("crate_name is required".to_string()))? - .to_string(); - - let version = arguments - .get("version") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let doc = this.lookup_crate(crate_name, version).await?; - Ok(vec![Content::text(doc)]) - } - "search_crates" => { - let query = arguments - .get("query") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::InvalidParameters("query is required".to_string()))? - .to_string(); - - let limit = arguments - .get("limit") - .and_then(|v| v.as_u64()) - .map(|v| v as u32); - - let results = this.search_crates(query, limit).await?; - Ok(vec![Content::text(results)]) - } - "lookup_item" => { - let crate_name = arguments - .get("crate_name") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::InvalidParameters("crate_name is required".to_string()))? - .to_string(); - - let item_path = arguments - .get("item_path") - .and_then(|v| v.as_str()) - .ok_or_else(|| ToolError::InvalidParameters("item_path is required".to_string()))? - .to_string(); - - let version = arguments - .get("version") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - let doc = this.lookup_item(crate_name, item_path, version).await?; - Ok(vec![Content::text(doc)]) - } - _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), - } - }) - } - - fn list_resources(&self) -> Vec { - vec![] - } - - fn read_resource( - &self, - _uri: &str, - ) -> Pin> + Send + 'static>> { - Box::pin(async move { - Err(ResourceError::NotFound("Resource not found".to_string())) - }) - } - - fn list_prompts(&self) -> Vec { - vec![] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))) - }) - } -} \ No newline at end of file diff --git a/src/docs/tests.rs b/src/docs/tests.rs deleted file mode 100644 index 6aa2fe1..0000000 --- a/src/docs/tests.rs +++ /dev/null @@ -1,152 +0,0 @@ -use super::*; -use mcp_core::{Content, ToolError}; -use mcp_server::Router; -use serde_json::json; - -#[tokio::test] -async fn test_doc_cache() { - let cache = DocCache::new(); - - // Initial get should return None - let result = cache.get("test_key").await; - assert_eq!(result, None); - - // Set and get should return the value - cache.set("test_key".to_string(), "test_value".to_string()).await; - let result = cache.get("test_key").await; - assert_eq!(result, Some("test_value".to_string())); -} - -#[tokio::test] -async fn test_router_capabilities() { - let router = DocRouter::new(); - - // Test basic properties - assert_eq!(router.name(), "rust-docs"); - assert!(router.instructions().contains("documentation")); - - // Test capabilities - let capabilities = router.capabilities(); - assert!(capabilities.tools.is_some()); - // Only assert that tools are supported, skip resources checks since they might be configured - // differently depending on the SDK version -} - -#[tokio::test] -async fn test_list_tools() { - let router = DocRouter::new(); - let tools = router.list_tools(); - - // Should have exactly 3 tools - assert_eq!(tools.len(), 3); - - // Check tool names - let tool_names: Vec = tools.iter().map(|t| t.name.clone()).collect(); - assert!(tool_names.contains(&"lookup_crate".to_string())); - assert!(tool_names.contains(&"search_crates".to_string())); - assert!(tool_names.contains(&"lookup_item".to_string())); -} - -#[tokio::test] -async fn test_invalid_tool_call() { - let router = DocRouter::new(); - let result = router.call_tool("invalid_tool", json!({})).await; - - // Should return NotFound error - assert!(matches!(result, Err(ToolError::NotFound(_)))); - if let Err(ToolError::NotFound(msg)) = result { - assert!(msg.contains("invalid_tool")); - } -} - -#[tokio::test] -async fn test_lookup_crate_missing_parameter() { - let router = DocRouter::new(); - let result = router.call_tool("lookup_crate", json!({})).await; - - // Should return InvalidParameters error - assert!(matches!(result, Err(ToolError::InvalidParameters(_)))); -} - -#[tokio::test] -async fn test_search_crates_missing_parameter() { - let router = DocRouter::new(); - let result = router.call_tool("search_crates", json!({})).await; - - // Should return InvalidParameters error - assert!(matches!(result, Err(ToolError::InvalidParameters(_)))); -} - -#[tokio::test] -async fn test_lookup_item_missing_parameters() { - let router = DocRouter::new(); - - // Missing both parameters - let result = router.call_tool("lookup_item", json!({})).await; - assert!(matches!(result, Err(ToolError::InvalidParameters(_)))); - - // Missing item_path - let result = router.call_tool("lookup_item", json!({ - "crate_name": "tokio" - })).await; - assert!(matches!(result, Err(ToolError::InvalidParameters(_)))); -} - -// Requires network access, can be marked as ignored if needed -#[tokio::test] -#[ignore = "Requires network access"] -async fn test_lookup_crate_integration() { - let router = DocRouter::new(); - let result = router.call_tool("lookup_crate", json!({ - "crate_name": "serde" - })).await; - - assert!(result.is_ok()); - let contents = result.unwrap(); - assert_eq!(contents.len(), 1); - if let Content::Text(text) = &contents[0] { - assert!(text.text.contains("serde")); - } else { - panic!("Expected text content"); - } -} - -// Requires network access, can be marked as ignored if needed -#[tokio::test] -#[ignore = "Requires network access"] -async fn test_search_crates_integration() { - let router = DocRouter::new(); - let result = router.call_tool("search_crates", json!({ - "query": "json", - "limit": 5 - })).await; - - assert!(result.is_ok()); - let contents = result.unwrap(); - assert_eq!(contents.len(), 1); - if let Content::Text(text) = &contents[0] { - assert!(text.text.contains("crates")); - } else { - panic!("Expected text content"); - } -} - -// Requires network access, can be marked as ignored if needed -#[tokio::test] -#[ignore = "Requires network access"] -async fn test_lookup_item_integration() { - let router = DocRouter::new(); - let result = router.call_tool("lookup_item", json!({ - "crate_name": "serde", - "item_path": "ser::Serializer" - })).await; - - assert!(result.is_ok()); - let contents = result.unwrap(); - assert_eq!(contents.len(), 1); - if let Content::Text(text) = &contents[0] { - assert!(text.text.contains("Serializer")); - } else { - panic!("Expected text content"); - } -} \ No newline at end of file diff --git a/src/jsonrpc_frame_codec.rs b/src/jsonrpc_frame_codec.rs deleted file mode 100644 index a799b28..0000000 --- a/src/jsonrpc_frame_codec.rs +++ /dev/null @@ -1,27 +0,0 @@ -use tokio_util::codec::Decoder; - -#[derive(Default)] -pub struct JsonRpcFrameCodec; - -#[cfg(test)] -mod tests; -impl Decoder for JsonRpcFrameCodec { - type Item = tokio_util::bytes::Bytes; - type Error = tokio::io::Error; - fn decode( - &mut self, - src: &mut tokio_util::bytes::BytesMut, - ) -> Result, Self::Error> { - if let Some(end) = src - .iter() - .enumerate() - .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) - { - let line = src.split_to(end); - let _char_next_line = src.split_to(1); - Ok(Some(line.freeze())) - } else { - Ok(None) - } - } -} \ No newline at end of file diff --git a/src/jsonrpc_frame_codec/tests.rs b/src/jsonrpc_frame_codec/tests.rs deleted file mode 100644 index f5b2910..0000000 --- a/src/jsonrpc_frame_codec/tests.rs +++ /dev/null @@ -1,72 +0,0 @@ -use super::*; -use tokio_util::bytes::BytesMut; - -#[test] -fn test_decode_single_line() { - let mut codec = JsonRpcFrameCodec::default(); - let mut buffer = BytesMut::from(r#"{"jsonrpc":"2.0","method":"test"}"#); - buffer.extend_from_slice(b"\n"); - - let result = codec.decode(&mut buffer).unwrap(); - - // Should decode successfully - assert!(result.is_some()); - let bytes = result.unwrap(); - assert_eq!(bytes, r#"{"jsonrpc":"2.0","method":"test"}"#); - - // Buffer should be empty after decoding - assert_eq!(buffer.len(), 0); -} - -#[test] -fn test_decode_incomplete_frame() { - let mut codec = JsonRpcFrameCodec::default(); - let mut buffer = BytesMut::from(r#"{"jsonrpc":"2.0","method":"test""#); - - // Should return None when no newline is found - let result = codec.decode(&mut buffer).unwrap(); - assert!(result.is_none()); - - // Buffer should still contain the incomplete frame - assert_eq!(buffer.len(), 32); -} - -#[test] -fn test_decode_multiple_frames() { - let mut codec = JsonRpcFrameCodec::default(); - let json1 = r#"{"jsonrpc":"2.0","method":"test1"}"#; - let json2 = r#"{"jsonrpc":"2.0","method":"test2"}"#; - - let mut buffer = BytesMut::new(); - buffer.extend_from_slice(json1.as_bytes()); - buffer.extend_from_slice(b"\n"); - buffer.extend_from_slice(json2.as_bytes()); - buffer.extend_from_slice(b"\n"); - - // First decode should return the first frame - let result1 = codec.decode(&mut buffer).unwrap(); - assert!(result1.is_some()); - assert_eq!(result1.unwrap(), json1); - - // Second decode should return the second frame - let result2 = codec.decode(&mut buffer).unwrap(); - assert!(result2.is_some()); - assert_eq!(result2.unwrap(), json2); - - // Buffer should be empty after decoding both frames - assert_eq!(buffer.len(), 0); -} - -#[test] -fn test_decode_empty_line() { - let mut codec = JsonRpcFrameCodec::default(); - let mut buffer = BytesMut::from("\n"); - - // Should return an empty frame - let result = codec.decode(&mut buffer).unwrap(); - assert!(result.is_some()); - assert_eq!(result.unwrap().len(), 0); - - // Buffer should be empty - assert_eq!(buffer.len(), 0); -} \ No newline at end of file