Skip to content

Commit dbf6e3d

Browse files
committed
server: Add error enum while deal server info
1. wrap the error type for more standardized 2. add more information in error for debug trace 3. wrap helper func for more user-friendly code Signed-off-by: jokemanfire <hu.dingyang@zte.com.cn>
1 parent ecff1eb commit dbf6e3d

File tree

1 file changed

+90
-46
lines changed

1 file changed

+90
-46
lines changed

crates/rmcp/src/service/server.rs

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
use crate::model::{
2-
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientNotification,
3-
ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam,
4-
CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification,
5-
LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam,
6-
PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification,
7-
ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, ServerNotification, ServerRequest,
8-
ServerResult, ToolListChangedNotification,
2+
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientMessage, ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam, CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification
93
};
10-
4+
use thiserror::Error;
115
use super::*;
126
use futures::{SinkExt, StreamExt};
137

@@ -26,6 +20,24 @@ impl ServiceRole for RoleServer {
2620
const IS_CLIENT: bool = false;
2721
}
2822

23+
/// It represents the error that may occur when serving the server.
24+
///
25+
/// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result<RunningService<RoleServer, S>, ServerError>`
26+
#[derive(Error, Debug)]
27+
pub enum ServerError {
28+
#[error("expect initialize request, but received: {0:?}")]
29+
ExpectedInitRequest(Option<ClientMessage>),
30+
31+
#[error("expect initialize notification, but received: {0:?}")]
32+
ExpectedInitNotification(Option<ClientMessage>),
33+
34+
#[error("connection closed: {0}")]
35+
ConnectionClosed(String),
36+
37+
#[error("IO error: {0}")]
38+
Io(#[from] std::io::Error),
39+
}
40+
2941
pub type ClientSink = Peer<RoleServer>;
3042

3143
impl<S: Service<RoleServer>> ServiceExt<RoleServer> for S {
@@ -55,6 +67,46 @@ where
5567
serve_server_with_ct(service, transport, CancellationToken::new()).await
5668
}
5769

70+
/// Helper function to get the next message from the stream
71+
async fn expect_next_message<S>(
72+
stream: &mut S,
73+
context: &str
74+
) -> Result<ClientMessage, ServerError>
75+
where S: StreamExt<Item = ClientJsonRpcMessage> + Unpin
76+
{
77+
Ok(stream
78+
.next()
79+
.await
80+
.ok_or_else(|| ServerError::ConnectionClosed(context.to_string()))?
81+
.into_message())
82+
}
83+
84+
/// Helper function to expect a request from the stream
85+
async fn expect_request<S>(
86+
stream: &mut S,
87+
context: &str
88+
) -> Result<(ClientRequest, RequestId), ServerError>
89+
where S: StreamExt<Item = ClientJsonRpcMessage> + Unpin
90+
{
91+
let msg = expect_next_message(stream, context).await?;
92+
let msg_clone = msg.clone();
93+
msg.into_request()
94+
.ok_or_else(|| ServerError::ExpectedInitRequest(Some(msg_clone)))
95+
}
96+
97+
/// Helper function to expect a notification from the stream
98+
async fn expect_notification<S>(
99+
stream: &mut S,
100+
context: &str
101+
) -> Result<ClientNotification, ServerError>
102+
where S: StreamExt<Item = ClientJsonRpcMessage> + Unpin
103+
{
104+
let msg = expect_next_message(stream, context).await?;
105+
let msg_clone = msg.clone();
106+
msg.into_notification()
107+
.ok_or_else(|| ServerError::ExpectedInitNotification(Some(msg_clone)))
108+
}
109+
58110
pub async fn serve_server_with_ct<S, T, E, A>(
59111
service: S,
60112
transport: T,
@@ -70,54 +122,46 @@ where
70122
let mut stream = Box::pin(stream);
71123
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
72124

73-
// service
74-
let (request, id) = stream
75-
.next()
76-
.await
77-
.ok_or(std::io::Error::new(
78-
std::io::ErrorKind::UnexpectedEof,
79-
"expect initialize request",
80-
))?
81-
.into_message()
82-
.into_request()
83-
.ok_or(std::io::Error::new(
84-
std::io::ErrorKind::InvalidData,
85-
"expect initialize request",
86-
))?;
125+
// Convert ServerError to std::io::Error, then to E
126+
let handle_server_error = |e: ServerError| -> E {
127+
match e {
128+
ServerError::Io(io_err) => io_err.into(),
129+
other => std::io::Error::new(
130+
std::io::ErrorKind::Other,
131+
format!("{}", other)
132+
).into()
133+
}
134+
};
135+
136+
// Get initialize request
137+
let (request, id) = expect_request(&mut stream, "initialize request")
138+
.await.map_err(handle_server_error)?;
139+
87140
let ClientRequest::InitializeRequest(peer_info) = request else {
88-
return Err(std::io::Error::new(
89-
std::io::ErrorKind::InvalidData,
90-
"expect initialize request",
91-
)
92-
.into());
141+
return Err(handle_server_error(ServerError::ExpectedInitRequest(
142+
Some(ClientMessage::Request(request, id))
143+
)));
93144
};
145+
146+
// Send initialize response
94147
let init_response = service.get_info();
95148
sink.send(
96149
ServerMessage::Response(ServerResult::InitializeResult(init_response), id)
97150
.into_json_rpc_message(),
98151
)
99152
.await?;
100-
// waiting for notification
101-
let notification = stream
102-
.next()
103-
.await
104-
.ok_or(std::io::Error::new(
105-
std::io::ErrorKind::UnexpectedEof,
106-
"expect initialize notification",
107-
))?
108-
.into_message()
109-
.into_notification()
110-
.ok_or(std::io::Error::new(
111-
std::io::ErrorKind::InvalidData,
112-
"expect initialize notification",
113-
))?;
153+
154+
// Wait for initialize notification
155+
let notification = expect_notification(&mut stream, "initialize notification")
156+
.await.map_err(handle_server_error)?;
157+
114158
let ClientNotification::InitializedNotification(_) = notification else {
115-
return Err(std::io::Error::new(
116-
std::io::ErrorKind::InvalidData,
117-
"expect initialize notification",
118-
)
119-
.into());
159+
return Err(handle_server_error(ServerError::ExpectedInitNotification(
160+
Some(ClientMessage::Notification(notification))
161+
)));
120162
};
163+
164+
// Continue processing service
121165
serve_inner(service, (sink, stream), peer_info.params, id_provider, ct).await
122166
}
123167

0 commit comments

Comments
 (0)