Skip to content

chore: refactor server, prepare for browser reuse #490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@
* limitations under the License.
*/

import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { Server as McpServer } from '@modelcontextprotocol/sdk/server/index.js';
import { CallToolRequestSchema, ListToolsRequestSchema, Tool as McpTool } from '@modelcontextprotocol/sdk/types.js';
import { zodToJsonSchema } from 'zod-to-json-schema';

import { Context, packageJSON } from './context.js';
import { Context } from './context.js';
import { snapshotTools, visionTools } from './tools.js';
import { packageJSON } from './package.js';

import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { FullConfig } from './config.js';

export async function createConnection(config: FullConfig): Promise<Connection> {
export function createConnection(config: FullConfig): Connection {
const allTools = config.vision ? visionTools : snapshotTools;
const tools = allTools.filter(tool => !config.capabilities || tool.capability === 'core' || config.capabilities.includes(tool.capability));

const context = new Context(tools, config);
const server = new Server({ name: 'Playwright', version: packageJSON.version }, {
const server = new McpServer({ name: 'Playwright', version: packageJSON.version }, {
capabilities: {
tools: {},
}
Expand Down Expand Up @@ -74,25 +74,19 @@ export async function createConnection(config: FullConfig): Promise<Connection>
}
});

const connection = new Connection(server, context);
return connection;
return new Connection(server, context);
}

export class Connection {
readonly server: Server;
readonly server: McpServer;
readonly context: Context;

constructor(server: Server, context: Context) {
constructor(server: McpServer, context: Context) {
this.server = server;
this.context = context;
}

async connect(transport: Transport) {
await this.server.connect(transport);
await new Promise<void>(resolve => {
this.server.oninitialized = () => resolve();
});
this.context.clientVersion = this.server.getClientVersion();
this.server.oninitialized = () => {
this.context.clientVersion = this.server.getClientVersion();
};
}

async close() {
Expand Down
4 changes: 0 additions & 4 deletions src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/

import fs from 'node:fs';
import url from 'node:url';
import os from 'node:os';
import path from 'node:path';

Expand Down Expand Up @@ -416,6 +415,3 @@ async function createUserDataDir(browserConfig: FullConfig['browser']) {
await fs.promises.mkdir(result, { recursive: true });
return result;
}

const __filename = url.fileURLToPath(import.meta.url);
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));
22 changes: 22 additions & 0 deletions src/package.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import fs from 'node:fs';
import url from 'node:url';
import path from 'node:path';

const __filename = url.fileURLToPath(import.meta.url);
export const packageJSON = JSON.parse(fs.readFileSync(path.join(path.dirname(__filename), '..', 'package.json'), 'utf8'));
32 changes: 9 additions & 23 deletions src/program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
*/

import { program } from 'commander';

import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveCLIConfig } from './config.js';
// @ts-ignore
import { startTraceViewerServer } from 'playwright-core/lib/server';

import type { Connection } from './connection.js';
import { packageJSON } from './context.js';
import { startHttpTransport, startStdioTransport } from './transport.js';
import { resolveCLIConfig } from './config.js';
import { Server } from './server.js';
import { packageJSON } from './package.js';

program
.version('Version ' + packageJSON.version)
Expand Down Expand Up @@ -54,13 +53,13 @@ program
.option('--vision', 'Run server that uses screenshots (Aria snapshots are used by default)')
.action(async options => {
const config = await resolveCLIConfig(options);
const connectionList: Connection[] = [];
setupExitWatchdog(connectionList);
const server = new Server(config);
server.setupExitWatchdog();

if (options.port)
startHttpTransport(config, +options.port, options.host, connectionList);
startHttpTransport(server, +options.port, options.host);
else
await startStdioTransport(config, connectionList);
await startStdioTransport(server);

if (config.saveTrace) {
const server = await startTraceViewerServer();
Expand All @@ -71,21 +70,8 @@ program
}
});

function setupExitWatchdog(connectionList: Connection[]) {
const handleExit = async () => {
setTimeout(() => process.exit(0), 15000);
for (const connection of connectionList)
await connection.close();
process.exit(0);
};

process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}

function semicolonSeparatedList(value: string): string[] {
return value.split(';').map(v => v.trim());
}

program.parse(process.argv);
void program.parseAsync(process.argv);
49 changes: 49 additions & 0 deletions src/server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**
* Copyright (c) Microsoft Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { createConnection } from './connection.js';

import type { FullConfig } from './config.js';
import type { Connection } from './connection.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';

export class Server {
readonly config: FullConfig;
private _connectionList: Connection[] = [];

constructor(config: FullConfig) {
this.config = config;
}

async createConnection(transport: Transport): Promise<Connection> {
const connection = createConnection(this.config);
this._connectionList.push(connection);
await connection.server.connect(transport);
return connection;
}

setupExitWatchdog() {
const handleExit = async () => {
setTimeout(() => process.exit(0), 15000);
await Promise.all(this._connectionList.map(connection => connection.close()));
process.exit(0);
};

process.stdin.on('close', handleExit);
process.on('SIGINT', handleExit);
process.on('SIGTERM', handleExit);
}
}
39 changes: 13 additions & 26 deletions src/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,13 @@ import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';

import { createConnection } from './connection.js';
import type { Server } from './server.js';

import type { Connection } from './connection.js';
import type { FullConfig } from './config.js';

export async function startStdioTransport(config: FullConfig, connectionList: Connection[]) {
const connection = await createConnection(config);
await connection.connect(new StdioServerTransport());
connectionList.push(connection);
export async function startStdioTransport(server: Server) {
await server.createConnection(new StdioServerTransport());
}

async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>, connectionList: Connection[]) {
async function handleSSE(server: Server, req: http.IncomingMessage, res: http.ServerResponse, url: URL, sessions: Map<string, SSEServerTransport>) {
if (req.method === 'POST') {
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
Expand All @@ -51,15 +46,11 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
} else if (req.method === 'GET') {
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
const connection = await createConnection(config);
await connection.connect(transport);
connectionList.push(connection);
const connection = await server.createConnection(transport);
res.on('close', () => {
sessions.delete(transport.sessionId);
connection.close().catch(e => {
// eslint-disable-next-line no-console
console.error(e);
});
// eslint-disable-next-line no-console
void connection.close().catch(e => console.error(e));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe await?

});
return;
}
Expand All @@ -68,7 +59,7 @@ async function handleSSE(config: FullConfig, req: http.IncomingMessage, res: htt
res.end('Method not allowed');
}

async function handleStreamable(config: FullConfig, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>, connectionList: Connection[]) {
async function handleStreamable(server: Server, req: http.IncomingMessage, res: http.ServerResponse, sessions: Map<string, StreamableHTTPServerTransport>) {
const sessionId = req.headers['mcp-session-id'] as string | undefined;
if (sessionId) {
const transport = sessions.get(sessionId);
Expand All @@ -91,28 +82,24 @@ async function handleStreamable(config: FullConfig, req: http.IncomingMessage, r
if (transport.sessionId)
sessions.delete(transport.sessionId);
};
const connection = await createConnection(config);
connectionList.push(connection);
await Promise.all([
connection.connect(transport),
transport.handleRequest(req, res),
]);
await server.createConnection(transport);
await transport.handleRequest(req, res);
return;
}

res.statusCode = 400;
res.end('Invalid request');
}

export function startHttpTransport(config: FullConfig, port: number, hostname: string | undefined, connectionList: Connection[]) {
export function startHttpTransport(server: Server, port: number, hostname: string | undefined) {
const sseSessions = new Map<string, SSEServerTransport>();
const streamableSessions = new Map<string, StreamableHTTPServerTransport>();
const httpServer = http.createServer(async (req, res) => {
const url = new URL(`http://localhost${req.url}`);
if (url.pathname.startsWith('/mcp'))
await handleStreamable(config, req, res, streamableSessions, connectionList);
await handleStreamable(server, req, res, streamableSessions);
else
await handleSSE(config, req, res, url, sseSessions, connectionList);
await handleSSE(server, req, res, url, sseSessions);
});
httpServer.listen(port, hostname, () => {
const address = httpServer.address();
Expand Down
49 changes: 1 addition & 48 deletions tests/sse.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,12 @@
*/

import url from 'node:url';
import http from 'node:http';
import { spawn } from 'node:child_process';
import path from 'node:path';
import type { AddressInfo } from 'node:net';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';

import { createConnection } from '@playwright/mcp';

import { test as baseTest, expect } from './fixtures.js';

// NOTE: Can be removed when we drop Node.js 18 support and changed to import.meta.filename.
Expand Down Expand Up @@ -55,6 +50,7 @@ test('sse transport', async ({ serverEndpoint }) => {
const client = new Client({ name: 'test', version: '1.0.0' });
await client.connect(transport);
await client.ping();
await client.close();
});

test('streamable http transport', async ({ serverEndpoint }) => {
Expand All @@ -64,46 +60,3 @@ test('streamable http transport', async ({ serverEndpoint }) => {
await client.ping();
expect(transport.sessionId, 'has session support').toBeDefined();
});

test('sse transport via public API', async ({ server }, testInfo) => {
const userDataDir = testInfo.outputPath('user-data-dir');
const sessions = new Map<string, SSEServerTransport>();
const mcpServer = http.createServer(async (req, res) => {
if (req.method === 'GET') {
const connection = await createConnection({
browser: {
userDataDir,
launchOptions: { headless: true }
},
});
const transport = new SSEServerTransport('/sse', res);
sessions.set(transport.sessionId, transport);
await connection.connect(transport);
} else if (req.method === 'POST') {
const url = new URL(`http://localhost${req.url}`);
const sessionId = url.searchParams.get('sessionId');
if (!sessionId) {
res.statusCode = 400;
return res.end('Missing sessionId');
}
const transport = sessions.get(sessionId);
if (!transport) {
res.statusCode = 404;
return res.end('Session not found');
}
void transport.handlePostMessage(req, res);
}
});
await new Promise<void>(resolve => mcpServer.listen(0, () => resolve()));
const serverUrl = `http://localhost:${(mcpServer.address() as AddressInfo).port}/sse`;
const transport = new SSEClientTransport(new URL(serverUrl));
const client = new Client({ name: 'test', version: '1.0.0' });
await client.connect(transport);
await client.ping();
expect(await client.callTool({
name: 'browser_navigate',
arguments: { url: server.HELLO_WORLD },
})).toContainTextContent(`- generic [ref=e1]: Hello, world!`);
await client.close();
mcpServer.close();
});