|
1 |
| -use async_trait::async_trait; |
2 |
| -use tokio_stream::StreamExt; |
| 1 | +#[cfg(feature = "hyper-server")] |
| 2 | +pub mod test_server_common { |
| 3 | + use async_trait::async_trait; |
| 4 | + use tokio_stream::StreamExt; |
3 | 5 |
|
4 |
| -use rust_mcp_schema::{ |
5 |
| - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, |
6 |
| - LATEST_PROTOCOL_VERSION, |
7 |
| -}; |
8 |
| -use rust_mcp_sdk::{ |
9 |
| - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, |
10 |
| - McpServer, SessionId, |
11 |
| -}; |
12 |
| -use std::sync::RwLock; |
13 |
| -use std::time::Duration; |
14 |
| -use tokio::time::timeout; |
| 6 | + use rust_mcp_schema::{ |
| 7 | + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, |
| 8 | + LATEST_PROTOCOL_VERSION, |
| 9 | + }; |
| 10 | + use rust_mcp_sdk::{ |
| 11 | + mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, |
| 12 | + McpServer, SessionId, |
| 13 | + }; |
| 14 | + use std::sync::RwLock; |
| 15 | + use std::time::Duration; |
| 16 | + use tokio::time::timeout; |
15 | 17 |
|
16 |
| -pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; |
17 |
| -pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; |
| 18 | + pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; |
| 19 | + pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; |
18 | 20 |
|
19 |
| -pub fn test_server_details() -> InitializeResult { |
20 |
| - InitializeResult { |
21 |
| - // server name and version |
22 |
| - server_info: Implementation { |
23 |
| - name: "Test MCP Server".to_string(), |
24 |
| - version: "0.1.0".to_string(), |
25 |
| - }, |
26 |
| - capabilities: ServerCapabilities { |
27 |
| - // indicates that server support mcp tools |
28 |
| - tools: Some(ServerCapabilitiesTools { list_changed: None }), |
29 |
| - ..Default::default() // Using default values for other fields |
30 |
| - }, |
31 |
| - meta: None, |
32 |
| - instructions: Some("server instructions...".to_string()), |
33 |
| - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), |
| 21 | + pub fn test_server_details() -> InitializeResult { |
| 22 | + InitializeResult { |
| 23 | + // server name and version |
| 24 | + server_info: Implementation { |
| 25 | + name: "Test MCP Server".to_string(), |
| 26 | + version: "0.1.0".to_string(), |
| 27 | + }, |
| 28 | + capabilities: ServerCapabilities { |
| 29 | + // indicates that server support mcp tools |
| 30 | + tools: Some(ServerCapabilitiesTools { list_changed: None }), |
| 31 | + ..Default::default() // Using default values for other fields |
| 32 | + }, |
| 33 | + meta: None, |
| 34 | + instructions: Some("server instructions...".to_string()), |
| 35 | + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), |
| 36 | + } |
34 | 37 | }
|
35 |
| -} |
36 | 38 |
|
37 |
| -pub struct TestServerHandler; |
| 39 | + pub struct TestServerHandler; |
38 | 40 |
|
39 |
| -#[async_trait] |
40 |
| -impl ServerHandler for TestServerHandler { |
41 |
| - async fn on_server_started(&self, runtime: &dyn McpServer) { |
42 |
| - let _ = runtime |
43 |
| - .stderr_message("Server started successfully".into()) |
44 |
| - .await; |
| 41 | + #[async_trait] |
| 42 | + impl ServerHandler for TestServerHandler { |
| 43 | + async fn on_server_started(&self, runtime: &dyn McpServer) { |
| 44 | + let _ = runtime |
| 45 | + .stderr_message("Server started successfully".into()) |
| 46 | + .await; |
| 47 | + } |
45 | 48 | }
|
46 |
| -} |
47 | 49 |
|
48 |
| -pub fn create_test_server(options: HyperServerOptions) -> HyperServer { |
49 |
| - hyper_server::create_server(test_server_details(), TestServerHandler {}, options) |
50 |
| -} |
| 50 | + pub fn create_test_server(options: HyperServerOptions) -> HyperServer { |
| 51 | + hyper_server::create_server(test_server_details(), TestServerHandler {}, options) |
| 52 | + } |
51 | 53 |
|
52 |
| -// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. |
53 |
| -pub struct TestIdGenerator { |
54 |
| - constant_ids: Vec<SessionId>, |
55 |
| - generated: RwLock<usize>, |
56 |
| -} |
| 54 | + // Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. |
| 55 | + pub struct TestIdGenerator { |
| 56 | + constant_ids: Vec<SessionId>, |
| 57 | + generated: RwLock<usize>, |
| 58 | + } |
57 | 59 |
|
58 |
| -impl TestIdGenerator { |
59 |
| - pub fn new(constant_ids: Vec<SessionId>) -> Self { |
60 |
| - TestIdGenerator { |
61 |
| - constant_ids, |
62 |
| - generated: RwLock::new(0), |
| 60 | + impl TestIdGenerator { |
| 61 | + pub fn new(constant_ids: Vec<SessionId>) -> Self { |
| 62 | + TestIdGenerator { |
| 63 | + constant_ids, |
| 64 | + generated: RwLock::new(0), |
| 65 | + } |
63 | 66 | }
|
64 | 67 | }
|
65 |
| -} |
66 | 68 |
|
67 |
| -impl IdGenerator for TestIdGenerator { |
68 |
| - fn generate(&self) -> SessionId { |
69 |
| - let mut lock = self.generated.write().unwrap(); |
70 |
| - *lock += 1; |
71 |
| - if *lock > self.constant_ids.len() { |
72 |
| - *lock = 1; |
| 69 | + impl IdGenerator for TestIdGenerator { |
| 70 | + fn generate(&self) -> SessionId { |
| 71 | + let mut lock = self.generated.write().unwrap(); |
| 72 | + *lock += 1; |
| 73 | + if *lock > self.constant_ids.len() { |
| 74 | + *lock = 1; |
| 75 | + } |
| 76 | + self.constant_ids[*lock - 1].to_owned() |
73 | 77 | }
|
74 |
| - self.constant_ids[*lock - 1].to_owned() |
75 | 78 | }
|
76 |
| -} |
77 | 79 |
|
78 |
| -pub async fn collect_sse_lines( |
79 |
| - response: reqwest::Response, |
80 |
| - line_count: usize, |
81 |
| - read_timeout: Duration, |
82 |
| -) -> Result<Vec<String>, Box<dyn std::error::Error>> { |
83 |
| - let mut collected_lines = Vec::new(); |
84 |
| - let mut stream = response.bytes_stream(); |
| 80 | + pub async fn collect_sse_lines( |
| 81 | + response: reqwest::Response, |
| 82 | + line_count: usize, |
| 83 | + read_timeout: Duration, |
| 84 | + ) -> Result<Vec<String>, Box<dyn std::error::Error>> { |
| 85 | + let mut collected_lines = Vec::new(); |
| 86 | + let mut stream = response.bytes_stream(); |
85 | 87 |
|
86 |
| - let result = timeout(read_timeout, async { |
87 |
| - while let Some(chunk) = stream.next().await { |
88 |
| - let chunk = chunk.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?; |
89 |
| - let chunk_str = String::from_utf8_lossy(&chunk); |
| 88 | + let result = timeout(read_timeout, async { |
| 89 | + while let Some(chunk) = stream.next().await { |
| 90 | + let chunk = chunk.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?; |
| 91 | + let chunk_str = String::from_utf8_lossy(&chunk); |
90 | 92 |
|
91 |
| - // Split the chunk into lines |
92 |
| - let lines: Vec<&str> = chunk_str.lines().collect(); |
| 93 | + // Split the chunk into lines |
| 94 | + let lines: Vec<&str> = chunk_str.lines().collect(); |
93 | 95 |
|
94 |
| - // Add each line to the collected_lines vector |
95 |
| - for line in lines { |
96 |
| - collected_lines.push(line.to_string()); |
| 96 | + // Add each line to the collected_lines vector |
| 97 | + for line in lines { |
| 98 | + collected_lines.push(line.to_string()); |
97 | 99 |
|
98 |
| - // Check if we have collected 5 lines |
99 |
| - if collected_lines.len() >= line_count { |
100 |
| - return Ok(collected_lines); |
| 100 | + // Check if we have collected 5 lines |
| 101 | + if collected_lines.len() >= line_count { |
| 102 | + return Ok(collected_lines); |
| 103 | + } |
101 | 104 | }
|
102 | 105 | }
|
103 |
| - } |
104 |
| - // If the stream ends before collecting 5 lines, return what we have |
105 |
| - Ok(collected_lines) |
106 |
| - }) |
107 |
| - .await; |
| 106 | + // If the stream ends before collecting 5 lines, return what we have |
| 107 | + Ok(collected_lines) |
| 108 | + }) |
| 109 | + .await; |
108 | 110 |
|
109 |
| - // Handle timeout or stream result |
110 |
| - match result { |
111 |
| - Ok(Ok(lines)) => Ok(lines), |
112 |
| - Ok(Err(e)) => Err(e), |
113 |
| - Err(_) => Err(Box::new(std::io::Error::new( |
114 |
| - std::io::ErrorKind::TimedOut, |
115 |
| - "Timed out waiting for 5 lines", |
116 |
| - ))), |
| 111 | + // Handle timeout or stream result |
| 112 | + match result { |
| 113 | + Ok(Ok(lines)) => Ok(lines), |
| 114 | + Ok(Err(e)) => Err(e), |
| 115 | + Err(_) => Err(Box::new(std::io::Error::new( |
| 116 | + std::io::ErrorKind::TimedOut, |
| 117 | + "Timed out waiting for 5 lines", |
| 118 | + ))), |
| 119 | + } |
117 | 120 | }
|
118 | 121 | }
|
0 commit comments