Skip to content

Add support for DNS rebinding protections #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

Expand Down Expand Up @@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
*/
private volatile boolean isClosing = false;

/**
* DNS rebinding protection configuration.
*/
private final DnsRebindingProtection dnsRebindingProtection;

/**
* Constructs a new WebFlux SSE server transport provider instance with the default
* SSE endpoint.
Expand All @@ -118,8 +124,10 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if either parameter is null
*/
@Deprecated
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
}
Expand All @@ -131,10 +139,12 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if either parameter is null
*/
@Deprecated
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null);
}

/**
Expand All @@ -145,10 +155,31 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if either parameter is null
*/
@Deprecated
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebFlux SSE server transport provider instance with optional DNS
* rebinding protection.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @param dnsRebindingProtection The DNS rebinding protection configuration (may be
* null).
* @throws IllegalArgumentException if required parameters are null
*/
private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +189,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.dnsRebindingProtection = dnsRebindingProtection;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -256,6 +288,12 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

// Validate headers
Mono<ServerResponse> validationError = validateDnsRebindingProtection(request);
if (validationError != null) {
return validationError;
}

return ServerResponse.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
.body(Flux.<ServerSentEvent<?>>create(sink -> {
Expand Down Expand Up @@ -300,6 +338,19 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down");
}

// Always validate Content-Type for POST requests
String contentType = request.headers().contentType().map(MediaType::toString).orElse(null);
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
logger.warn("Invalid Content-Type header: '{}'", contentType);
return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json"));
}

// Validate headers for POST requests if DNS rebinding protection is configured
Mono<ServerResponse> validationError = validateDnsRebindingProtection(request);
if (validationError != null) {
return validationError;
}

if (request.queryParam("sessionId").isEmpty()) {
return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint"));
}
Expand Down Expand Up @@ -397,6 +448,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private DnsRebindingProtection dnsRebindingProtection;

/**
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -447,6 +500,22 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}

/**
* Sets the DNS rebinding protection configuration.
* <p>
* When set, this configuration will be used to create a header validator that
* enforces DNS rebinding protection rules. This will override any previously set
* header validator.
* @param config The DNS rebinding protection configuration
* @return this builder instance
* @throws IllegalArgumentException if config is null
*/
public Builder dnsRebindingProtection(DnsRebindingProtection config) {
Assert.notNull(config, "DNS rebinding protection config must not be null");
this.dnsRebindingProtection = config;
return this;
}

/**
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
* configured settings.
Expand All @@ -457,9 +526,30 @@ public WebFluxSseServerTransportProvider build() {
Assert.notNull(objectMapper, "ObjectMapper must be set");
Assert.notNull(messageEndpoint, "Message endpoint must be set");

return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
dnsRebindingProtection);
}

}

/**
* Validates DNS rebinding protection for the given request.
* @param request The incoming server request
* @return A ServerResponse with forbidden status if validation fails, or null if
* validation passes
*/
private Mono<ServerResponse> validateDnsRebindingProtection(ServerRequest request) {
if (dnsRebindingProtection != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader,
originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN)
.bodyValue("DNS rebinding protection validation failed");
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
*/
private volatile boolean isClosing = false;

/**
* DNS rebinding protection configuration.
*/
private final DnsRebindingProtection dnsRebindingProtection;

/**
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
* endpoint.
Expand All @@ -115,8 +121,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
*/
@Deprecated
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
}
Expand All @@ -129,10 +137,12 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if any parameter is null
*/
@Deprecated
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
this(objectMapper, "", messageEndpoint, sseEndpoint);
this(objectMapper, "", messageEndpoint, sseEndpoint, null);
}

/**
Expand All @@ -145,10 +155,32 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @deprecated Use {@link #builder()} instead.
* @throws IllegalArgumentException if any parameter is null
*/
@Deprecated
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding
* protection.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param baseUrl The base URL for the message endpoint, used to construct the full
* endpoint URL for clients.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @param dnsRebindingProtection The DNS rebinding protection configuration (may be
* null).
* @throws IllegalArgumentException if any required parameter is null
*/
private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -158,6 +190,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.dnsRebindingProtection = dnsRebindingProtection;
this.routerFunction = RouterFunctions.route()
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
Expand Down Expand Up @@ -247,6 +280,12 @@ private ServerResponse handleSseConnection(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

// Validate headers
ServerResponse validationError = validateDnsRebindingProtection(request);
if (validationError != null) {
return validationError;
}

String sessionId = UUID.randomUUID().toString();
logger.debug("Creating new SSE connection for session: {}", sessionId);

Expand Down Expand Up @@ -300,6 +339,19 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
}

// Always validate Content-Type for POST requests
String contentType = request.headers().asHttpHeaders().getFirst("Content-Type");
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) {
logger.warn("Invalid Content-Type header: '{}'", contentType);
return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json"));
}

// Validate headers for POST requests if DNS rebinding protection is configured
ServerResponse validationError = validateDnsRebindingProtection(request);
if (validationError != null) {
return validationError;
}

if (request.param("sessionId").isEmpty()) {
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
}
Expand Down Expand Up @@ -417,4 +469,23 @@ public void close() {

}

/**
* Validates DNS rebinding protection for the given request.
* @param request The incoming server request
* @return A ServerResponse with forbidden status if validation fails, or null if
* validation passes
*/
private ServerResponse validateDnsRebindingProtection(ServerRequest request) {
if (dnsRebindingProtection != null) {
String hostHeader = request.headers().asHttpHeaders().getFirst("Host");
String originHeader = request.headers().asHttpHeaders().getFirst("Origin");
if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) {
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader,
originHeader);
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed");
}
}
return null;
}

}
Loading