Skip to content

Commit f6c3fde

Browse files
committed
DNS rebinding protection: check host header
Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent b57b126 commit f6c3fde

File tree

5 files changed

+551
-119
lines changed

5 files changed

+551
-119
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java

Lines changed: 83 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212

1313
/**
1414
* Default implementation of {@link ServerTransportSecurityValidator} that validates the
15-
* Origin header against a list of allowed origins.
15+
* Origin and Host headers against lists of allowed values.
1616
*
1717
* <p>
18-
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*").
18+
* Supports exact matches and wildcard port patterns (e.g., "http://example.com:*" for
19+
* origins, "example.com:*" for hosts).
1920
*
2021
* @author Daniel Garnier-Moiroux
2122
* @see ServerTransportSecurityValidator
@@ -25,29 +26,49 @@ public final class DefaultServerTransportSecurityValidator implements ServerTran
2526

2627
private static final String ORIGIN_HEADER = "Origin";
2728

29+
private static final String HOST_HEADER = "Host";
30+
2831
private final List<String> allowedOrigins;
2932

33+
private final List<String> allowedHosts;
34+
3035
/**
31-
* Creates a new validator with the specified allowed origins.
36+
* Creates a new validator with the specified allowed origins and hosts.
3237
* @param allowedOrigins List of allowed origin patterns. Supports exact matches
3338
* (e.g., "http://example.com:8080") and wildcard ports (e.g., "http://example.com:*")
39+
* @param allowedHosts List of allowed host patterns. Supports exact matches (e.g.,
40+
* "example.com:8080") and wildcard ports (e.g., "example.com:*")
3441
*/
35-
private DefaultServerTransportSecurityValidator(List<String> allowedOrigins) {
42+
private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, List<String> allowedHosts) {
3643
Assert.notNull(allowedOrigins, "allowedOrigins must not be null");
44+
Assert.notNull(allowedHosts, "allowedHosts must not be null");
3745
this.allowedOrigins = allowedOrigins;
46+
this.allowedHosts = allowedHosts;
3847
}
3948

4049
@Override
4150
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
51+
boolean missingHost = true;
4252
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
4353
if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) {
4454
List<String> values = entry.getValue();
45-
if (values != null && !values.isEmpty()) {
46-
validateOrigin(values.get(0));
55+
if (values == null || values.isEmpty()) {
56+
throw new ServerTransportSecurityException(403, "Invalid Origin header");
57+
}
58+
validateOrigin(values.get(0));
59+
}
60+
else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) {
61+
missingHost = false;
62+
List<String> values = entry.getValue();
63+
if (values == null || values.isEmpty()) {
64+
throw new ServerTransportSecurityException(421, "Invalid Host header");
4765
}
48-
break;
66+
validateHost(values.get(0));
4967
}
5068
}
69+
if (!allowedHosts.isEmpty() && missingHost) {
70+
throw new ServerTransportSecurityException(421, "Invalid Host header");
71+
}
5172
}
5273

5374
/**
@@ -79,6 +100,37 @@ else if (allowed.endsWith(":*")) {
79100
throw new ServerTransportSecurityException(403, "Invalid Origin header");
80101
}
81102

103+
/**
104+
* Validates a single host value against the allowed hosts.
105+
* @param host The host header value, or null if not present
106+
* @throws ServerTransportSecurityException if the host is not allowed
107+
*/
108+
private void validateHost(String host) throws ServerTransportSecurityException {
109+
if (allowedHosts.isEmpty()) {
110+
return;
111+
}
112+
113+
// Host is required
114+
if (host == null || host.isBlank()) {
115+
throw new ServerTransportSecurityException(421, "Invalid Host header");
116+
}
117+
118+
for (String allowed : allowedHosts) {
119+
if (allowed.equals(host)) {
120+
return;
121+
}
122+
else if (allowed.endsWith(":*")) {
123+
// Wildcard port pattern: "example.com:*"
124+
String baseHost = allowed.substring(0, allowed.length() - 2);
125+
if (host.equals(baseHost) || host.startsWith(baseHost + ":")) {
126+
return;
127+
}
128+
}
129+
}
130+
131+
throw new ServerTransportSecurityException(421, "Invalid Host header");
132+
}
133+
82134
/**
83135
* Creates a new builder for constructing a DefaultServerTransportSecurityValidator.
84136
* @return A new builder instance
@@ -94,6 +146,8 @@ public static class Builder {
94146

95147
private final List<String> allowedOrigins = new ArrayList<>();
96148

149+
private final List<String> allowedHosts = new ArrayList<>();
150+
97151
/**
98152
* Adds an allowed origin pattern.
99153
* @param origin The origin to allow (e.g., "http://localhost:8080" or
@@ -116,12 +170,33 @@ public Builder allowedOrigins(List<String> origins) {
116170
return this;
117171
}
118172

173+
/**
174+
* Adds an allowed host pattern.
175+
* @param host The host to allow (e.g., "localhost:8080" or "example.com:*")
176+
* @return this builder instance
177+
*/
178+
public Builder allowedHost(String host) {
179+
this.allowedHosts.add(host);
180+
return this;
181+
}
182+
183+
/**
184+
* Adds multiple allowed host patterns.
185+
* @param hosts The hosts to allow
186+
* @return this builder instance
187+
*/
188+
public Builder allowedHosts(List<String> hosts) {
189+
Assert.notNull(hosts, "hosts must not be null");
190+
this.allowedHosts.addAll(hosts);
191+
return this;
192+
}
193+
119194
/**
120195
* Builds the validator instance.
121196
* @return A new DefaultServerTransportSecurityValidator
122197
*/
123198
public DefaultServerTransportSecurityValidator build() {
124-
return new DefaultServerTransportSecurityValidator(allowedOrigins);
199+
return new DefaultServerTransportSecurityValidator(allowedOrigins, allowedHosts);
125200
}
126201

127202
}

0 commit comments

Comments
 (0)