Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import securityHeadersMiddleware from '../securityHeadersMiddleware';

describe('securityHeadersMiddleware', () => {
let req, res, next;

beforeEach(() => {
req = {
headers: {},
};
res = {
setHeader: () => {},
};
next = jest.fn();
});

afterEach(() => {
jest.clearAllMocks();
});

it('should block requests from different origins', () => {
req.headers.origin = 'https://example.com';
const middleware = securityHeadersMiddleware({});
middleware(req, res, next);
expect(next).toHaveBeenCalledWith(expect.any(Error));
});

it('should allow requests from localhost', () => {
req.headers.origin = 'http://localhost:3000';
const middleware = securityHeadersMiddleware({});
middleware(req, res, next);
expect(next).toHaveBeenCalled();
});

it('should allow requests from devtools', () => {
req.headers.origin = 'devtools://devtools';
const middleware = securityHeadersMiddleware({});
middleware(req, res, next);
expect(next).toHaveBeenCalled();
});

it('should allow requests from custom host if provided in options', () => {
req.headers.origin = 'http://customhost.com';
const middleware = securityHeadersMiddleware({host: 'customhost.com'});
middleware(req, res, next);
expect(next).toHaveBeenCalled();
});

it('should block requests from custom host if provided in options but not matching', () => {
req.headers.origin = 'http://anotherhost.com';
const middleware = securityHeadersMiddleware({host: 'customhost.com'});
middleware(req, res, next);
expect(next).toHaveBeenCalledWith(expect.any(Error));
});
});
2 changes: 1 addition & 1 deletion packages/cli-server-api/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function createDevServerMiddleware(options: MiddlewareOptions) {
const eventsSocketEndpoint = createEventsSocketEndpoint(broadcast);

const middleware = connect()
.use(securityHeadersMiddleware)
.use(securityHeadersMiddleware(options))
// @ts-ignore compression and connect types mismatch
.use(compression())
.use(nocache())
Expand Down
55 changes: 35 additions & 20 deletions packages/cli-server-api/src/securityHeadersMiddleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,44 @@
*/
import http from 'http';

export default function securityHeadersMiddleware(
type MiddlewareOptions = {
host?: string;
};

type MiddlewareFn = (
req: http.IncomingMessage,
res: http.ServerResponse,
next: (err?: any) => void,
) {
// Block any cross origin request.
if (
typeof req.headers.origin === 'string' &&
!req.headers.origin.match(/^https?:\/\/localhost:/) &&
!req.headers.origin.startsWith('devtools://devtools')
) {
next(
new Error(
'Unauthorized request from ' +
req.headers.origin +
'. This may happen because of a conflicting browser extension. Please try to disable it and try again.',
),
);
return;
}
) => void;

export default function securityHeadersMiddleware(
options: MiddlewareOptions,
): MiddlewareFn {
return (
req: http.IncomingMessage,
res: http.ServerResponse,
next: (err?: any) => void,
) => {
const host = options.host ? options.host : 'localhost';
// Block any cross origin request.
if (
typeof req.headers.origin === 'string' &&
!req.headers.origin.match(new RegExp('^https?://' + host + ':')) &&
!req.headers.origin.startsWith('devtools://devtools')
) {
next(
new Error(
'Unauthorized request from ' +
req.headers.origin +
'. This may happen because of a conflicting browser extension. Please try to disable it and try again.',
),
);
return;
}

// Block MIME-type sniffing.
res.setHeader('X-Content-Type-Options', 'nosniff');
// Block MIME-type sniffing.
res.setHeader('X-Content-Type-Options', 'nosniff');

next();
next();
};
}