Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,66 @@ import {
SUPPORTED_PROTOCOL_VERSIONS,
type SubscribeRequest,
type Tool,
type UnsubscribeRequest
type UnsubscribeRequest,
ElicitResultSchema,
ElicitRequestSchema
} from '../types.js';
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
import { ZodLiteral, ZodObject, z } from 'zod';
import type { RequestHandlerExtra } from '../shared/protocol.js';

function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void {
if (!schema || data === null || typeof data !== 'object') return;

// Handle object properties
if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') {
const obj = data as Record<string, unknown>;
const props = schema.properties as Record<string, JsonSchemaType & { default?: unknown }>;
for (const key of Object.keys(props)) {
const propSchema = props[key];
// If missing or explicitly undefined, apply default if present
if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) {
obj[key] = propSchema.default;
}
// Recurse into existing nested objects/arrays
if (obj[key] !== undefined) {
applyElicitationDefaults(propSchema, obj[key]);
}
}
}

// Handle arrays
if (schema.type === 'array' && Array.isArray(data) && schema.items) {
const itemsSchema = schema.items as JsonSchemaType | JsonSchemaType[];
if (Array.isArray(itemsSchema)) {
for (let i = 0; i < data.length && i < itemsSchema.length; i++) {
applyElicitationDefaults(itemsSchema[i], data[i]);
}
} else {
for (const item of data) {
applyElicitationDefaults(itemsSchema, item);
}
}
}

// Combine schemas
if (Array.isArray(schema.allOf)) {
for (const sub of schema.allOf) {
applyElicitationDefaults(sub, data);
}
}
if (Array.isArray(schema.anyOf)) {
for (const sub of schema.anyOf) {
applyElicitationDefaults(sub, data);
}
}
if (Array.isArray(schema.oneOf)) {
for (const sub of schema.oneOf) {
applyElicitationDefaults(sub, data);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

}

export type ClientOptions = ProtocolOptions & {
/**
Expand Down Expand Up @@ -141,6 +197,64 @@ export class Client<
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
}

/**
* Override request handler registration to enforce client-side validation for elicitation.
*/
public override setRequestHandler<
T extends ZodObject<{
method: ZodLiteral<string>;
}>
>(
requestSchema: T,
handler: (
request: z.infer<T>,
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
): void {
const method = requestSchema.shape.method.value;
if (method === 'elicitation/create') {
const wrappedHandler = async (
request: z.infer<T>,
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
): Promise<ClientResult | ResultT> => {
const validatedRequest = ElicitRequestSchema.safeParse(request);
if (!validatedRequest.success) {
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`);
}

const result = await Promise.resolve(handler(request, extra));

const validationResult = ElicitResultSchema.safeParse(result);
if (!validationResult.success) {
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`);
}

const validatedResult = validationResult.data;

if (
this._capabilities.elicitation?.applyDefaults &&
validatedResult.action === 'accept' &&
validatedResult.content &&
validatedRequest.data.params.requestedSchema
) {
try {
applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content);
} catch {
// gracefully ignore errors in default application
}
}

return validatedResult;
};

// Install the wrapped handler
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
}

// Non-elicitation handlers use default behavior
return super.setRequestHandler(requestSchema, handler);
}

protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(`Server does not support ${capability} (required for ${method})`);
Expand Down
121 changes: 110 additions & 11 deletions src/server/elicitation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -73,7 +76,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -109,7 +115,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -145,7 +154,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

const userData = {
name: 'Jane Smith',
Expand Down Expand Up @@ -200,7 +212,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -237,7 +252,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -274,7 +292,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -307,7 +328,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -340,7 +364,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'accept',
Expand Down Expand Up @@ -374,7 +401,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'decline'
Expand Down Expand Up @@ -408,7 +438,10 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
}
);

const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } });
const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{ capabilities: { elicitation: {} }, jsonSchemaValidator: validatorProvider }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need this anymore

);

client.setRequestHandler(ElicitRequestSchema, _request => ({
action: 'cancel'
Expand Down Expand Up @@ -609,6 +642,72 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
});
});

test(`${validatorName}: should default missing fields from schema defaults`, async () => {
const server = new Server(
{ name: 'test-server', version: '1.0.0' },
{
capabilities: {},
jsonSchemaValidator: validatorProvider
}
);

const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{
capabilities: {
elicitation: {
applyDefaults: true
}
}
}
);

// Client returns no values; SDK should apply defaults automatically (and validate)
client.setRequestHandler(ElicitRequestSchema, request => {
expect(request.params.requestedSchema).toEqual({
type: 'object',
properties: {
subscribe: { type: 'boolean', default: true },
nickname: { type: 'string', default: 'Guest' },
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
},
required: ['subscribe', 'nickname', 'age', 'color']
});
return {
action: 'accept',
content: {}
};
});

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);

const result = await server.elicitInput({
message: 'Provide your preferences',
requestedSchema: {
type: 'object',
properties: {
subscribe: { type: 'boolean', default: true },
nickname: { type: 'string', default: 'Guest' },
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
},
required: ['subscribe', 'nickname', 'age', 'color']
}
});

expect(result).toEqual({
action: 'accept',
content: {
subscribe: true,
nickname: 'Guest',
age: 18,
color: 'green'
}
});
});

test(`${validatorName}: should reject invalid email format`, async () => {
const server = new Server(
{ name: 'test-server', version: '1.0.0' },
Expand Down
Loading
Loading