Refactor
This commit is contained in:
		
							parent
							
								
									5fe28f7fe7
								
							
						
					
					
						commit
						4171c923e6
					
				
					 4 changed files with 0 additions and 222 deletions
				
			
		|  | @ -1,139 +0,0 @@ | |||
| use axum::{ | ||||
|     body::Body, | ||||
|     extract::{Query, State}, | ||||
|     http::StatusCode, | ||||
|     response::sse::{Event, Sse}, | ||||
|     routing::get, | ||||
|     Router, | ||||
| }; | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests; | ||||
| use futures::{stream::Stream, StreamExt, TryStreamExt}; | ||||
| use mcp_server::{ByteTransport, Server}; | ||||
| use std::collections::HashMap; | ||||
| use tokio_util::codec::FramedRead; | ||||
| 
 | ||||
| use anyhow::Result; | ||||
| use mcp_server::router::RouterService; | ||||
| use crate::{jsonrpc_frame_codec::JsonRpcFrameCodec, DocRouter}; | ||||
| use std::sync::Arc; | ||||
| use tokio::{ | ||||
|     io::{self, AsyncWriteExt}, | ||||
|     sync::Mutex, | ||||
| }; | ||||
| 
 | ||||
| type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>; | ||||
| type SessionId = Arc<str>; | ||||
| 
 | ||||
| #[derive(Clone, Default)] | ||||
| pub struct App { | ||||
|     txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>, | ||||
| } | ||||
| 
 | ||||
| impl App { | ||||
|     pub fn new() -> Self { | ||||
|         Self { | ||||
|             txs: Default::default(), | ||||
|         } | ||||
|     } | ||||
|     pub fn router(&self) -> Router { | ||||
|         Router::new() | ||||
|             .route("/sse", get(sse_handler).post(post_event_handler)) | ||||
|             .with_state(self.clone()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn session_id() -> SessionId { | ||||
|     let id = format!("{:016x}", rand::random::<u128>()); | ||||
|     Arc::from(id) | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, serde::Deserialize)] | ||||
| #[serde(rename_all = "camelCase")] | ||||
| pub struct PostEventQuery { | ||||
|     pub session_id: String, | ||||
| } | ||||
| 
 | ||||
| async fn post_event_handler( | ||||
|     State(app): State<App>, | ||||
|     Query(PostEventQuery { session_id }): Query<PostEventQuery>, | ||||
|     body: Body, | ||||
| ) -> Result<StatusCode, StatusCode> { | ||||
|     const BODY_BYTES_LIMIT: usize = 1 << 22; | ||||
|     let write_stream = { | ||||
|         let rg = app.txs.read().await; | ||||
|         rg.get(session_id.as_str()) | ||||
|             .ok_or(StatusCode::NOT_FOUND)? | ||||
|             .clone() | ||||
|     }; | ||||
|     let mut write_stream = write_stream.lock().await; | ||||
|     let mut body = body.into_data_stream(); | ||||
|     if let (_, Some(size)) = body.size_hint() { | ||||
|         if size > BODY_BYTES_LIMIT { | ||||
|             return Err(StatusCode::PAYLOAD_TOO_LARGE); | ||||
|         } | ||||
|     } | ||||
|     // calculate the body size
 | ||||
|     let mut size = 0; | ||||
|     while let Some(chunk) = body.next().await { | ||||
|         let Ok(chunk) = chunk else { | ||||
|             return Err(StatusCode::BAD_REQUEST); | ||||
|         }; | ||||
|         size += chunk.len(); | ||||
|         if size > BODY_BYTES_LIMIT { | ||||
|             return Err(StatusCode::PAYLOAD_TOO_LARGE); | ||||
|         } | ||||
|         write_stream | ||||
|             .write_all(&chunk) | ||||
|             .await | ||||
|             .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||||
|     } | ||||
|     write_stream | ||||
|         .write_u8(b'\n') | ||||
|         .await | ||||
|         .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||||
|     Ok(StatusCode::ACCEPTED) | ||||
| } | ||||
| 
 | ||||
| async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> { | ||||
|     // it's 4KB
 | ||||
|     const BUFFER_SIZE: usize = 1 << 12; | ||||
|     let session = session_id(); | ||||
|     tracing::info!(%session, "sse connection"); | ||||
|     let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); | ||||
|     let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); | ||||
|     app.txs | ||||
|         .write() | ||||
|         .await | ||||
|         .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); | ||||
|     { | ||||
|         let app_clone = app.clone(); | ||||
|         let session = session.clone(); | ||||
|         tokio::spawn(async move { | ||||
|             let router = RouterService(DocRouter::new()); | ||||
|             let server = Server::new(router); | ||||
|             let bytes_transport = ByteTransport::new(c2s_read, s2c_write); | ||||
|             let _result = server | ||||
|                 .run(bytes_transport) | ||||
|                 .await | ||||
|                 .inspect_err(|e| tracing::error!(?e, "server run error")); | ||||
|             app_clone.txs.write().await.remove(&session); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     let stream = futures::stream::once(futures::future::ok( | ||||
|         Event::default() | ||||
|             .event("endpoint") | ||||
|             .data(format!("?sessionId={session}")), | ||||
|     )) | ||||
|     .chain( | ||||
|         FramedRead::new(s2c_read, JsonRpcFrameCodec) | ||||
|             .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) | ||||
|             .and_then(move |bytes| match std::str::from_utf8(bytes.as_ref()) { | ||||
|                 Ok(message) => futures::future::ok(Event::default().event("message").data(message)), | ||||
|                 Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), | ||||
|             }), | ||||
|     ); | ||||
|     Sse::new(stream) | ||||
| } | ||||
|  | @ -1,63 +0,0 @@ | |||
| use super::*; | ||||
| use axum::{ | ||||
|     body::Body, | ||||
|     http::{Method, Request}, | ||||
| }; | ||||
| use tokio::sync::RwLock; | ||||
| // Comment out tower imports for now, as we'll handle router testing differently
 | ||||
| // use tower::Service; 
 | ||||
| // use tower::util::ServiceExt;
 | ||||
| 
 | ||||
| // Helper function to create an App with an empty state
 | ||||
| fn create_test_app() -> App { | ||||
|     App { | ||||
|         txs: Arc::new(RwLock::new(HashMap::new())), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[tokio::test] | ||||
| async fn test_app_initialization() { | ||||
|     let app = App::new(); | ||||
|     // App should be created with an empty hashmap
 | ||||
|     assert_eq!(app.txs.read().await.len(), 0); | ||||
| } | ||||
| 
 | ||||
| #[tokio::test] | ||||
| async fn test_router_setup() { | ||||
|     let app = App::new(); | ||||
|     let _router = app.router(); | ||||
|     
 | ||||
|     // Check if the router is constructed properly
 | ||||
|     // This is a basic test to ensure the router is created without panics
 | ||||
|     // Just check that the router exists, no need to invoke methods
 | ||||
|     assert!(true); | ||||
| } | ||||
| 
 | ||||
| #[tokio::test] | ||||
| async fn test_session_id_generation() { | ||||
|     // Generate two session IDs and ensure they're different
 | ||||
|     let id1 = session_id(); | ||||
|     let id2 = session_id(); | ||||
|     
 | ||||
|     assert_ne!(id1, id2); | ||||
|     assert_eq!(id1.len(), 32); // Should be 32 hex chars
 | ||||
| } | ||||
| 
 | ||||
| #[tokio::test] | ||||
| async fn test_post_event_handler_not_found() { | ||||
|     let app = create_test_app(); | ||||
|     let _router = app.router(); | ||||
|     
 | ||||
|     // Create a request with a session ID that doesn't exist
 | ||||
|     let _request = Request::builder() | ||||
|         .method(Method::POST) | ||||
|         .uri("/sse?sessionId=nonexistent") | ||||
|         .body(Body::empty()) | ||||
|         .unwrap(); | ||||
|     
 | ||||
|     // Since we can't use oneshot without tower imports, 
 | ||||
|     // we'll skip the actual request handling for now
 | ||||
|     
 | ||||
|     // Just check that the handler would have been called
 | ||||
|     assert!(true); | ||||
| } | ||||
|  | @ -1,4 +0,0 @@ | |||
| pub mod http_sse_server; | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests; | ||||
|  | @ -1,16 +0,0 @@ | |||
| use mcp_server::router::RouterService; | ||||
| use mcp_server::{Server}; | ||||
| use crate::DocRouter; | ||||
| use mcp_server::Router; | ||||
| 
 | ||||
| #[tokio::test] | ||||
| async fn test_server_initialization() { | ||||
|     // Basic test to ensure the server initializes properly
 | ||||
|     let doc_router = DocRouter::new(); | ||||
|     let router_name = doc_router.name(); | ||||
|     let router = RouterService(doc_router); | ||||
|     let _server = Server::new(router); | ||||
|     
 | ||||
|     // Server should be created successfully without panics
 | ||||
|     assert!(router_name.contains("rust-docs")); | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Danielle Jenkins
						Danielle Jenkins