Skip to content

Commit 83ad262

Browse files
authored
feat(validation): support validation for subtypes of ServerRequest enum. (#81)
Co-authored-by: zhongyi51 <zhongyi51@users.noreply.github.com>
1 parent cbb55b0 commit 83ad262

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

src/generated_schema/2025_06_18/mcp_schema.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,8 @@ impl ::std::convert::From<EmbeddedResource> for ContentBlock {
13121312
/// </details>
13131313
#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
13141314
pub struct CreateMessageRequest {
1315+
// This field requires custom deserialization for validation.
1316+
#[serde(deserialize_with = "server_request_method_validation::deserialize_CreateMessageRequest_method")]
13151317
method: ::std::string::String,
13161318
pub params: CreateMessageRequestParams,
13171319
}
@@ -1625,6 +1627,8 @@ pub struct Cursor(pub ::std::string::String);
16251627
/// </details>
16261628
#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
16271629
pub struct ElicitRequest {
1630+
// This field requires custom deserialization for validation.
1631+
#[serde(deserialize_with = "server_request_method_validation::deserialize_ElicitRequest_method")]
16281632
method: ::std::string::String,
16291633
pub params: ElicitRequestParams,
16301634
}
@@ -3302,6 +3306,8 @@ structure or access specific locations that the client has permission to read fr
33023306
/// </details>
33033307
#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
33043308
pub struct ListRootsRequest {
3309+
// This field requires custom deserialization for validation.
3310+
#[serde(deserialize_with = "server_request_method_validation::deserialize_ListRootsRequest_method")]
33053311
method: ::std::string::String,
33063312
#[serde(default, skip_serializing_if = "::std::option::Option::is_none")]
33073313
pub params: ::std::option::Option<ListRootsRequestParams>,
@@ -4040,6 +4046,8 @@ pub struct PaginatedResult {
40404046
/// </details>
40414047
#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
40424048
pub struct PingRequest {
4049+
// This field requires custom deserialization for validation.
4050+
#[serde(deserialize_with = "server_request_method_validation::deserialize_PingRequest_method")]
40434051
method: ::std::string::String,
40444052
#[serde(default, skip_serializing_if = "::std::option::Option::is_none")]
40454053
pub params: ::std::option::Option<PingRequestParams>,
@@ -7167,6 +7175,112 @@ impl ServerNotification {
71677175
}
71687176
}
71697177
}
7178+
7179+
// Custom module for deserialization function to prevent name conflicts.
7180+
mod server_request_method_validation{
7181+
7182+
// Custom deserialization function, following the `deserialize_#StructName_#FieldName` format.
7183+
#[allow(non_snake_case)]
7184+
pub(super) fn deserialize_PingRequest_method<'de, D>(
7185+
deserializer: D,
7186+
) -> std::result::Result<String, D::Error>
7187+
where
7188+
D: serde::de::Deserializer<'de>,
7189+
{
7190+
let value = serde::Deserialize::deserialize(deserializer)?;
7191+
// The expected constant value.
7192+
let expected = "ping";
7193+
7194+
// Validate the deserialized value.
7195+
if value == expected {
7196+
Ok(value)
7197+
} else {
7198+
// The error message with format
7199+
// "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'"
7200+
Err(serde::de::Error::custom(format!(
7201+
"Expected field `method` in struct `PingRequest` as const value '{}', but got '{}'",
7202+
expected, value
7203+
)))
7204+
}
7205+
}
7206+
7207+
// Custom deserialization function, following the `deserialize_#StructName_#FieldName` format.
7208+
#[allow(non_snake_case)]
7209+
pub(super) fn deserialize_CreateMessageRequest_method<'de, D>(
7210+
deserializer: D,
7211+
) -> std::result::Result<String, D::Error>
7212+
where
7213+
D: serde::de::Deserializer<'de>,
7214+
{
7215+
let value = serde::Deserialize::deserialize(deserializer)?;
7216+
// The expected constant value.
7217+
let expected = "sampling/createMessage";
7218+
7219+
// Validate the deserialized value.
7220+
if value == expected {
7221+
Ok(value)
7222+
} else {
7223+
// The error message with format
7224+
// "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'"
7225+
Err(serde::de::Error::custom(format!(
7226+
"Expected field `method` in struct `CreateMessageRequest` as const value '{}', but got '{}'",
7227+
expected, value
7228+
)))
7229+
}
7230+
}
7231+
7232+
// Custom deserialization function, following the `deserialize_#StructName_#FieldName` format.
7233+
#[allow(non_snake_case)]
7234+
pub(super) fn deserialize_ListRootsRequest_method<'de, D>(
7235+
deserializer: D,
7236+
) -> std::result::Result<String, D::Error>
7237+
where
7238+
D: serde::de::Deserializer<'de>,
7239+
{
7240+
let value = serde::Deserialize::deserialize(deserializer)?;
7241+
// The expected constant value.
7242+
let expected = "roots/list";
7243+
7244+
// Validate the deserialized value.
7245+
if value == expected {
7246+
Ok(value)
7247+
} else {
7248+
// The error message with format
7249+
// "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'"
7250+
Err(serde::de::Error::custom(format!(
7251+
"Expected field `method` in struct `ListRootsRequest` as const value '{}', but got '{}'",
7252+
expected, value
7253+
)))
7254+
}
7255+
}
7256+
7257+
// Custom deserialization function, following the `deserialize_#StructName_#FieldName` format.
7258+
#[allow(non_snake_case)]
7259+
pub(super) fn deserialize_ElicitRequest_method<'de, D>(
7260+
deserializer: D,
7261+
) -> std::result::Result<String, D::Error>
7262+
where
7263+
D: serde::de::Deserializer<'de>,
7264+
{
7265+
let value = serde::Deserialize::deserialize(deserializer)?;
7266+
// The expected constant value.
7267+
let expected = "elicitation/create";
7268+
7269+
// Validate the deserialized value.
7270+
if value == expected {
7271+
Ok(value)
7272+
} else {
7273+
// The error message with format
7274+
// "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'"
7275+
Err(serde::de::Error::custom(format!(
7276+
"Expected field `method` in struct `ElicitRequest` as const value '{}', but got '{}'",
7277+
expected, value
7278+
)))
7279+
}
7280+
}
7281+
7282+
}
7283+
71707284
#[deprecated(since = "0.3.0", note = "Use `RpcError` instead.")]
71717285
pub type JsonrpcErrorError = RpcError;
71727286
#[deprecated(since = "0.7.0", note = "Use `ElicitRequestedSchema` instead.")]

src/generated_schema/2025_06_18/schema_utils.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4453,5 +4453,12 @@ mod tests {
44534453
// default
44544454
let result = detect_message_type(&json!({}));
44554455
assert!(matches!(result, MessageTypes::Request));
4456+
4457+
// assert method type validation
4458+
let should_err:std::result::Result<PingRequest,_> = serde_json::from_value(json!({
4459+
"method":"wrong_method",
4460+
"params":null
4461+
}));
4462+
assert!(should_err.is_err());
44564463
}
44574464
}

0 commit comments

Comments
 (0)