11use crate :: model:: {
2- CancelledNotification , CancelledNotificationParam , ClientInfo , ClientNotification ,
3- ClientRequest , ClientResult , CreateMessageRequest , CreateMessageRequestParam ,
4- CreateMessageResult , ListRootsRequest , ListRootsResult , LoggingMessageNotification ,
5- LoggingMessageNotificationParam , ProgressNotification , ProgressNotificationParam ,
6- PromptListChangedNotification , ResourceListChangedNotification , ResourceUpdatedNotification ,
7- ResourceUpdatedNotificationParam , ServerInfo , ServerMessage , ServerNotification , ServerRequest ,
8- ServerResult , ToolListChangedNotification ,
2+ CancelledNotification , CancelledNotificationParam , ClientInfo , ClientJsonRpcMessage , ClientMessage , ClientNotification , ClientRequest , ClientResult , CreateMessageRequest , CreateMessageRequestParam , CreateMessageResult , ListRootsRequest , ListRootsResult , LoggingMessageNotification , LoggingMessageNotificationParam , ProgressNotification , ProgressNotificationParam , PromptListChangedNotification , ResourceListChangedNotification , ResourceUpdatedNotification , ResourceUpdatedNotificationParam , ServerInfo , ServerMessage , ServerNotification , ServerRequest , ServerResult , ToolListChangedNotification
93} ;
10-
4+ use thiserror :: Error ;
115use super :: * ;
126use futures:: { SinkExt , StreamExt } ;
137
@@ -26,6 +20,24 @@ impl ServiceRole for RoleServer {
2620 const IS_CLIENT : bool = false ;
2721}
2822
23+ /// It represents the error that may occur when serving the server.
24+ ///
25+ /// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
26+ #[ derive( Error , Debug ) ]
27+ pub enum ServerError {
28+ #[ error( "expect initialize request, but received: {0:?}" ) ]
29+ ExpectedInitRequest ( Option < ClientMessage > ) ,
30+
31+ #[ error( "expect initialize notification, but received: {0:?}" ) ]
32+ ExpectedInitNotification ( Option < ClientMessage > ) ,
33+
34+ #[ error( "connection closed: {0}" ) ]
35+ ConnectionClosed ( String ) ,
36+
37+ #[ error( "IO error: {0}" ) ]
38+ Io ( #[ from] std:: io:: Error ) ,
39+ }
40+
2941pub type ClientSink = Peer < RoleServer > ;
3042
3143impl < S : Service < RoleServer > > ServiceExt < RoleServer > for S {
5567 serve_server_with_ct ( service, transport, CancellationToken :: new ( ) ) . await
5668}
5769
70+ /// Helper function to get the next message from the stream
71+ async fn expect_next_message < S > (
72+ stream : & mut S ,
73+ context : & str
74+ ) -> Result < ClientMessage , ServerError >
75+ where S : StreamExt < Item = ClientJsonRpcMessage > + Unpin
76+ {
77+ Ok ( stream
78+ . next ( )
79+ . await
80+ . ok_or_else ( || ServerError :: ConnectionClosed ( context. to_string ( ) ) ) ?
81+ . into_message ( ) )
82+ }
83+
84+ /// Helper function to expect a request from the stream
85+ async fn expect_request < S > (
86+ stream : & mut S ,
87+ context : & str
88+ ) -> Result < ( ClientRequest , RequestId ) , ServerError >
89+ where S : StreamExt < Item = ClientJsonRpcMessage > + Unpin
90+ {
91+ let msg = expect_next_message ( stream, context) . await ?;
92+ let msg_clone = msg. clone ( ) ;
93+ msg. into_request ( )
94+ . ok_or_else ( || ServerError :: ExpectedInitRequest ( Some ( msg_clone) ) )
95+ }
96+
97+ /// Helper function to expect a notification from the stream
98+ async fn expect_notification < S > (
99+ stream : & mut S ,
100+ context : & str
101+ ) -> Result < ClientNotification , ServerError >
102+ where S : StreamExt < Item = ClientJsonRpcMessage > + Unpin
103+ {
104+ let msg = expect_next_message ( stream, context) . await ?;
105+ let msg_clone = msg. clone ( ) ;
106+ msg. into_notification ( )
107+ . ok_or_else ( || ServerError :: ExpectedInitNotification ( Some ( msg_clone) ) )
108+ }
109+
58110pub async fn serve_server_with_ct < S , T , E , A > (
59111 service : S ,
60112 transport : T ,
@@ -70,54 +122,46 @@ where
70122 let mut stream = Box :: pin ( stream) ;
71123 let id_provider = <Arc < AtomicU32RequestIdProvider > >:: default ( ) ;
72124
73- // service
74- let ( request, id) = stream
75- . next ( )
76- . await
77- . ok_or ( std:: io:: Error :: new (
78- std:: io:: ErrorKind :: UnexpectedEof ,
79- "expect initialize request" ,
80- ) ) ?
81- . into_message ( )
82- . into_request ( )
83- . ok_or ( std:: io:: Error :: new (
84- std:: io:: ErrorKind :: InvalidData ,
85- "expect initialize request" ,
86- ) ) ?;
125+ // Convert ServerError to std::io::Error, then to E
126+ let handle_server_error = |e : ServerError | -> E {
127+ match e {
128+ ServerError :: Io ( io_err) => io_err. into ( ) ,
129+ other => std:: io:: Error :: new (
130+ std:: io:: ErrorKind :: Other ,
131+ format ! ( "{}" , other)
132+ ) . into ( )
133+ }
134+ } ;
135+
136+ // Get initialize request
137+ let ( request, id) = expect_request ( & mut stream, "initialize request" )
138+ . await . map_err ( handle_server_error) ?;
139+
87140 let ClientRequest :: InitializeRequest ( peer_info) = request else {
88- return Err ( std:: io:: Error :: new (
89- std:: io:: ErrorKind :: InvalidData ,
90- "expect initialize request" ,
91- )
92- . into ( ) ) ;
141+ return Err ( handle_server_error ( ServerError :: ExpectedInitRequest (
142+ Some ( ClientMessage :: Request ( request, id) )
143+ ) ) ) ;
93144 } ;
145+
146+ // Send initialize response
94147 let init_response = service. get_info ( ) ;
95148 sink. send (
96149 ServerMessage :: Response ( ServerResult :: InitializeResult ( init_response) , id)
97150 . into_json_rpc_message ( ) ,
98151 )
99152 . await ?;
100- // waiting for notification
101- let notification = stream
102- . next ( )
103- . await
104- . ok_or ( std:: io:: Error :: new (
105- std:: io:: ErrorKind :: UnexpectedEof ,
106- "expect initialize notification" ,
107- ) ) ?
108- . into_message ( )
109- . into_notification ( )
110- . ok_or ( std:: io:: Error :: new (
111- std:: io:: ErrorKind :: InvalidData ,
112- "expect initialize notification" ,
113- ) ) ?;
153+
154+ // Wait for initialize notification
155+ let notification = expect_notification ( & mut stream, "initialize notification" )
156+ . await . map_err ( handle_server_error) ?;
157+
114158 let ClientNotification :: InitializedNotification ( _) = notification else {
115- return Err ( std:: io:: Error :: new (
116- std:: io:: ErrorKind :: InvalidData ,
117- "expect initialize notification" ,
118- )
119- . into ( ) ) ;
159+ return Err ( handle_server_error ( ServerError :: ExpectedInitNotification (
160+ Some ( ClientMessage :: Notification ( notification) )
161+ ) ) ) ;
120162 } ;
163+
164+ // Continue processing service
121165 serve_inner ( service, ( sink, stream) , peer_info. params , id_provider, ct) . await
122166}
123167
0 commit comments