Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,41 @@ import {
SUPPORTED_PROTOCOL_VERSIONS,
type SubscribeRequest,
type Tool,
type UnsubscribeRequest
type UnsubscribeRequest,
ElicitResultSchema,
ElicitRequestSchema
} from '../types.js';
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
import { ZodLiteral, ZodObject, z } from 'zod';
import type { RequestHandlerExtra } from '../shared/protocol.js';

/**
* Elicitation default application helper. Applies defaults to the data based on the schema.
*
* @param schema - The schema to apply defaults to.
* @param data - The data to apply defaults to.
*/
function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void {
if (!schema || data === null || typeof data !== 'object') return;

// Handle object properties
if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') {
const obj = data as Record<string, unknown>;
const props = schema.properties as Record<string, JsonSchemaType & { default?: unknown }>;
for (const key of Object.keys(props)) {
const propSchema = props[key];
// If missing or explicitly undefined, apply default if present
if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) {
obj[key] = propSchema.default;
}
// Recurse into existing nested objects/arrays
if (obj[key] !== undefined) {
applyElicitationDefaults(propSchema, obj[key]);
}
}
}
}

export type ClientOptions = ProtocolOptions & {
/**
Expand Down Expand Up @@ -141,6 +172,64 @@ export class Client<
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
}

/**
* Override request handler registration to enforce client-side validation for elicitation.
*/
public override setRequestHandler<
T extends ZodObject<{
method: ZodLiteral<string>;
}>
>(
requestSchema: T,
handler: (
request: z.infer<T>,
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
): void {
const method = requestSchema.shape.method.value;
if (method === 'elicitation/create') {
const wrappedHandler = async (
request: z.infer<T>,
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
): Promise<ClientResult | ResultT> => {
const validatedRequest = ElicitRequestSchema.safeParse(request);
if (!validatedRequest.success) {
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`);
}

const result = await Promise.resolve(handler(request, extra));

const validationResult = ElicitResultSchema.safeParse(result);
if (!validationResult.success) {
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`);
}

const validatedResult = validationResult.data;

if (
this._capabilities.elicitation?.applyDefaults &&
validatedResult.action === 'accept' &&
validatedResult.content &&
validatedRequest.data.params.requestedSchema
) {
try {
applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content);
} catch {
// gracefully ignore errors in default application
}
}

return validatedResult;
};

// Install the wrapped handler
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
}

// Non-elicitation handlers use default behavior
return super.setRequestHandler(requestSchema, handler);
}

protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(`Server does not support ${capability} (required for ${method})`);
Expand Down
66 changes: 66 additions & 0 deletions src/server/elicitation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,72 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
});
});

test(`${validatorName}: should default missing fields from schema defaults`, async () => {
const server = new Server(
{ name: 'test-server', version: '1.0.0' },
{
capabilities: {},
jsonSchemaValidator: validatorProvider
}
);

const client = new Client(
{ name: 'test-client', version: '1.0.0' },
{
capabilities: {
elicitation: {
applyDefaults: true
}
}
}
);

// Client returns no values; SDK should apply defaults automatically (and validate)
client.setRequestHandler(ElicitRequestSchema, request => {
expect(request.params.requestedSchema).toEqual({
type: 'object',
properties: {
subscribe: { type: 'boolean', default: true },
nickname: { type: 'string', default: 'Guest' },
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
},
required: ['subscribe', 'nickname', 'age', 'color']
});
return {
action: 'accept',
content: {}
};
});

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);

const result = await server.elicitInput({
message: 'Provide your preferences',
requestedSchema: {
type: 'object',
properties: {
subscribe: { type: 'boolean', default: true },
nickname: { type: 'string', default: 'Guest' },
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
color: { type: 'string', enum: ['red', 'green'], default: 'green' }
},
required: ['subscribe', 'nickname', 'age', 'color']
}
});

expect(result).toEqual({
action: 'accept',
content: {
subscribe: true,
nickname: 'Guest',
age: 18,
color: 'green'
}
});
});

test(`${validatorName}: should reject invalid email format`, async () => {
const server = new Server(
{ name: 'test-server', version: '1.0.0' },
Expand Down
30 changes: 26 additions & 4 deletions src/spec.types.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ type MakeUnknownsNotOptional<T> =
}
: T;

// Targeted fix: in spec, treat ClientCapabilities.elicitation?: object as Record<string, unknown>
type FixSpecClientCapabilities<T> = T extends { elicitation?: object }
? Omit<T, 'elicitation'> & { elicitation?: Record<string, unknown> }
: T;

type FixSpecInitializeRequestParams<T> = T extends { capabilities: infer C }
? Omit<T, 'capabilities'> & { capabilities: FixSpecClientCapabilities<C> }
: T;

type FixSpecInitializeRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;

type FixSpecClientRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;

const sdkTypeChecks = {
RequestParams: (sdk: SDKTypes.RequestParams, spec: SpecTypes.RequestParams) => {
sdk = spec;
Expand All @@ -75,7 +88,10 @@ const sdkTypeChecks = {
sdk = spec;
spec = sdk;
},
InitializeRequestParams: (sdk: SDKTypes.InitializeRequestParams, spec: SpecTypes.InitializeRequestParams) => {
InitializeRequestParams: (
sdk: SDKTypes.InitializeRequestParams,
spec: FixSpecInitializeRequestParams<SpecTypes.InitializeRequestParams>
) => {
sdk = spec;
spec = sdk;
},
Expand Down Expand Up @@ -480,23 +496,29 @@ const sdkTypeChecks = {
sdk = spec;
spec = sdk;
},
InitializeRequest: (sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>, spec: SpecTypes.InitializeRequest) => {
InitializeRequest: (
sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>,
spec: FixSpecInitializeRequest<SpecTypes.InitializeRequest>
) => {
sdk = spec;
spec = sdk;
},
InitializeResult: (sdk: SDKTypes.InitializeResult, spec: SpecTypes.InitializeResult) => {
sdk = spec;
spec = sdk;
},
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: SpecTypes.ClientCapabilities) => {
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: FixSpecClientCapabilities<SpecTypes.ClientCapabilities>) => {
sdk = spec;
spec = sdk;
},
ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: SpecTypes.ServerCapabilities) => {
sdk = spec;
spec = sdk;
},
ClientRequest: (sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>, spec: SpecTypes.ClientRequest) => {
ClientRequest: (
sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>,
spec: FixSpecClientRequest<SpecTypes.ClientRequest>
) => {
sdk = spec;
spec = sdk;
},
Expand Down
47 changes: 30 additions & 17 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,17 @@ export const ClientCapabilitiesSchema = z.object({
/**
* Present if the client supports eliciting user input.
*/
elicitation: AssertObjectSchema.optional(),
elicitation: z.intersection(
z
.object({
/**
* Whether the client should apply defaults to the user input.
*/
applyDefaults: z.boolean().optional()
})
.optional(),
z.record(z.string(), z.unknown()).optional()
),
/**
* Present if the client supports listing roots.
*/
Expand Down Expand Up @@ -1198,49 +1208,52 @@ export const CreateMessageResultSchema = ResultSchema.extend({
*/
export const BooleanSchemaSchema = z.object({
type: z.literal('boolean'),
title: z.optional(z.string()),
description: z.optional(z.string()),
default: z.optional(z.boolean())
title: z.string().optional(),
description: z.string().optional(),
default: z.boolean().optional()
});

/**
* Primitive schema definition for string fields.
*/
export const StringSchemaSchema = z.object({
type: z.literal('string'),
title: z.optional(z.string()),
description: z.optional(z.string()),
minLength: z.optional(z.number()),
maxLength: z.optional(z.number()),
format: z.optional(z.enum(['email', 'uri', 'date', 'date-time']))
title: z.string().optional(),
description: z.string().optional(),
minLength: z.number().optional(),
maxLength: z.number().optional(),
format: z.enum(['email', 'uri', 'date', 'date-time']).optional(),
default: z.string().optional()
});

/**
* Primitive schema definition for number fields.
*/
export const NumberSchemaSchema = z.object({
type: z.enum(['number', 'integer']),
title: z.optional(z.string()),
description: z.optional(z.string()),
minimum: z.optional(z.number()),
maximum: z.optional(z.number())
title: z.string().optional(),
description: z.string().optional(),
minimum: z.number().optional(),
maximum: z.number().optional(),
default: z.number().optional()
});

/**
* Primitive schema definition for enum fields.
*/
export const EnumSchemaSchema = z.object({
type: z.literal('string'),
title: z.optional(z.string()),
description: z.optional(z.string()),
title: z.string().optional(),
description: z.string().optional(),
enum: z.array(z.string()),
enumNames: z.optional(z.array(z.string()))
enumNames: z.array(z.string()).optional(),
default: z.string().optional()
});

/**
* Union of all primitive schema definitions.
*/
export const PrimitiveSchemaDefinitionSchema = z.union([BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema, EnumSchemaSchema]);
export const PrimitiveSchemaDefinitionSchema = z.union([EnumSchemaSchema, BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema]);

/**
* Parameters for an `elicitation/create` request.
Expand Down
Loading