Skip to content

Implement DNS Rebinding Protections per spec #565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 27, 2025
Merged
22 changes: 21 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,11 @@ app.post('/mcp', async (req, res) => {
onsessioninitialized: (sessionId) => {
// Store the transport by session ID
transports[sessionId] = transport;
}
},
// DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server
// locally, make sure to set:
// enableDnsRebindingProtection: true,
// allowedHosts: ['127.0.0.1'],
});

// Clean up transport when closed
Expand Down Expand Up @@ -596,6 +600,22 @@ This stateless approach is useful for:
- RESTful scenarios where each request is independent
- Horizontally scaled deployments without shared session state

#### DNS Rebinding Protection

The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility.

**Important**: If you are running this server locally, enable DNS rebinding protection:

```typescript
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
enableDnsRebindingProtection: true,

allowedHosts: ['127.0.0.1', ...],
allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com']
});
```

### Testing and Debugging

To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.
Expand Down
262 changes: 261 additions & 1 deletion src/server/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,264 @@ describe('SSEServerTransport', () => {
expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`));
});
});
});

describe('DNS rebinding protection', () => {
beforeEach(() => {
jest.clearAllMocks();
});

describe('Host header validation', () => {
it('should accept requests with allowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000', 'example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
});

it('should reject requests without host header when allowedHosts is configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined');
});
});

describe('Origin header validation', () => {
it('should accept requests with allowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
origin: 'http://evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
});
});

describe('Content-Type validation', () => {
it('should accept requests with application/json content-type', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should accept requests with application/json with charset', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'application/json; charset=utf-8',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with non-application/json content-type when protection is enabled', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
headers: {
'content-type': 'text/plain',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('enableDnsRebindingProtection option', () => {
it('should skip all validations when enableDnsRebindingProtection is false', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: false,
});
await transport.start();

const mockReq = createMockRequest({
headers: {
host: 'evil.com',
origin: 'http://evil.com',
'content-type': 'text/plain',
}
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

// Should pass even with invalid headers because protection is disabled
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
// The error should be from content-type parsing, not DNS rebinding protection
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('Combined validations', () => {
it('should validate both host and origin when both are configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

// Valid host, invalid origin
const mockReq1 = createMockRequest({
headers: {
host: 'localhost:3000',
origin: 'http://evil.com',
'content-type': 'application/json',
}
});
const mockHandleRes1 = createMockResponse();

await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');

// Invalid host, valid origin
const mockReq2 = createMockRequest({
headers: {
host: 'evil.com',
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes2 = createMockResponse();

await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com');

// Both valid
const mockReq3 = createMockRequest({
headers: {
host: 'localhost:3000',
origin: 'http://localhost:3000',
'content-type': 'application/json',
}
});
const mockHandleRes3 = createMockResponse();

await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted');
});
});
});
});
64 changes: 64 additions & 0 deletions src/server/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@ import { URL } from 'url';

const MAXIMUM_MESSAGE_SIZE = "4mb";

/**
* Configuration options for SSEServerTransport.
*/
export interface SSEServerTransportOptions {
/**
* List of allowed host header values for DNS rebinding protection.
* If not specified, host validation is disabled.
*/
allowedHosts?: string[];

/**
* List of allowed origin header values for DNS rebinding protection.
* If not specified, origin validation is disabled.
*/
allowedOrigins?: string[];

/**
* Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
* Default is false for backwards compatibility.
*/
enableDnsRebindingProtection?: boolean;
}

/**
* Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests.
*
Expand All @@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb";
export class SSEServerTransport implements Transport {
private _sseResponse?: ServerResponse;
private _sessionId: string;
private _options: SSEServerTransportOptions;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
Expand All @@ -27,8 +51,39 @@ export class SSEServerTransport implements Transport {
constructor(
private _endpoint: string,
private res: ServerResponse,
options?: SSEServerTransportOptions,
) {
this._sessionId = randomUUID();
this._options = options || {enableDnsRebindingProtection: false};
}

/**
* Validates request headers for DNS rebinding protection.
* @returns Error message if validation fails, undefined if validation passes.
*/
private validateRequestHeaders(req: IncomingMessage): string | undefined {
// Skip validation if protection is not enabled
if (!this._options.enableDnsRebindingProtection) {
return undefined;
}

// Validate Host header if allowedHosts is configured
if (this._options.allowedHosts && this._options.allowedHosts.length > 0) {
const hostHeader = req.headers.host;
if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) {
return `Invalid Host header: ${hostHeader}`;
}
}

// Validate Origin header if allowedOrigins is configured
if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) {
const originHeader = req.headers.origin;
if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) {
return `Invalid Origin header: ${originHeader}`;
}
}

return undefined;
}

/**
Expand Down Expand Up @@ -85,6 +140,15 @@ export class SSEServerTransport implements Transport {
res.writeHead(500).end(message);
throw new Error(message);
}

// Validate request headers for DNS rebinding protection
const validationError = this.validateRequestHeaders(req);
if (validationError) {
res.writeHead(403).end(validationError);
this.onerror?.(new Error(validationError));
return;
}

const authInfo: AuthInfo | undefined = req.auth;
const requestInfo: RequestInfo = { headers: req.headers };

Expand Down
Loading
Loading