Skip to content

Commit fb8238c

Browse files
committed
SEP-1330: compatibility with SEP-1034, add test coverage for enum types
2 parents 86e367f + ce420f8 commit fb8238c

File tree

5 files changed

+447
-34
lines changed

5 files changed

+447
-34
lines changed

src/client/index.ts

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,54 @@ import {
3636
SUPPORTED_PROTOCOL_VERSIONS,
3737
type SubscribeRequest,
3838
type Tool,
39-
type UnsubscribeRequest
39+
type UnsubscribeRequest,
40+
ElicitResultSchema,
41+
ElicitRequestSchema
4042
} from '../types.js';
4143
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
4244
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
45+
import { ZodLiteral, ZodObject, z } from 'zod';
46+
import type { RequestHandlerExtra } from '../shared/protocol.js';
47+
48+
/**
49+
* Elicitation default application helper. Applies defaults to the data based on the schema.
50+
*
51+
* @param schema - The schema to apply defaults to.
52+
* @param data - The data to apply defaults to.
53+
*/
54+
function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void {
55+
if (!schema || data === null || typeof data !== 'object') return;
56+
57+
// Handle object properties
58+
if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') {
59+
const obj = data as Record<string, unknown>;
60+
const props = schema.properties as Record<string, JsonSchemaType & { default?: unknown }>;
61+
for (const key of Object.keys(props)) {
62+
const propSchema = props[key];
63+
// If missing or explicitly undefined, apply default if present
64+
if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) {
65+
obj[key] = propSchema.default;
66+
}
67+
// Recurse into existing nested objects/arrays
68+
if (obj[key] !== undefined) {
69+
applyElicitationDefaults(propSchema, obj[key]);
70+
}
71+
}
72+
}
73+
74+
if (Array.isArray(schema.anyOf)) {
75+
for (const sub of schema.anyOf) {
76+
applyElicitationDefaults(sub, data);
77+
}
78+
}
79+
80+
// Combine schemas
81+
if (Array.isArray(schema.oneOf)) {
82+
for (const sub of schema.oneOf) {
83+
applyElicitationDefaults(sub, data);
84+
}
85+
}
86+
}
4387

4488
export type ClientOptions = ProtocolOptions & {
4589
/**
@@ -141,6 +185,64 @@ export class Client<
141185
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
142186
}
143187

188+
/**
189+
* Override request handler registration to enforce client-side validation for elicitation.
190+
*/
191+
public override setRequestHandler<
192+
T extends ZodObject<{
193+
method: ZodLiteral<string>;
194+
}>
195+
>(
196+
requestSchema: T,
197+
handler: (
198+
request: z.infer<T>,
199+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
200+
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
201+
): void {
202+
const method = requestSchema.shape.method.value;
203+
if (method === 'elicitation/create') {
204+
const wrappedHandler = async (
205+
request: z.infer<T>,
206+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
207+
): Promise<ClientResult | ResultT> => {
208+
const validatedRequest = ElicitRequestSchema.safeParse(request);
209+
if (!validatedRequest.success) {
210+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`);
211+
}
212+
213+
const result = await Promise.resolve(handler(request, extra));
214+
215+
const validationResult = ElicitResultSchema.safeParse(result);
216+
if (!validationResult.success) {
217+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`);
218+
}
219+
220+
const validatedResult = validationResult.data;
221+
222+
if (
223+
this._capabilities.elicitation?.applyDefaults &&
224+
validatedResult.action === 'accept' &&
225+
validatedResult.content &&
226+
validatedRequest.data.params.requestedSchema
227+
) {
228+
try {
229+
applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content);
230+
} catch {
231+
// gracefully ignore errors in default application
232+
}
233+
}
234+
235+
return validatedResult;
236+
};
237+
238+
// Install the wrapped handler
239+
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
240+
}
241+
242+
// Non-elicitation handlers use default behavior
243+
return super.setRequestHandler(requestSchema, handler);
244+
}
245+
144246
protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
145247
if (!this._serverCapabilities?.[capability]) {
146248
throw new Error(`Server does not support ${capability} (required for ${method})`);

src/server/elicitation.test.ts

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import { Client } from '../client/index.js';
1111
import { InMemoryTransport } from '../inMemory.js';
12-
import { ElicitRequestSchema } from '../types.js';
12+
import { ElicitRequestParams, ElicitRequestSchema, PrimitiveSchemaDefinition } from '../types.js';
1313
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
1414
import { CfWorkerJsonSchemaValidator } from '../validation/cfworker-provider.js';
1515
import { Server } from './index.js';
@@ -449,6 +449,122 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo
449449
});
450450
});
451451

452+
test(`${validatorName}: should default missing fields from schema defaults`, async () => {
453+
const server = new Server(
454+
{ name: 'test-server', version: '1.0.0' },
455+
{
456+
capabilities: {},
457+
jsonSchemaValidator: validatorProvider
458+
}
459+
);
460+
461+
const client = new Client(
462+
{ name: 'test-client', version: '1.0.0' },
463+
{
464+
capabilities: {
465+
elicitation: {
466+
applyDefaults: true
467+
}
468+
}
469+
}
470+
);
471+
472+
const testSchemaProperties: ElicitRequestParams['requestedSchema'] = {
473+
type: 'object',
474+
properties: {
475+
subscribe: { type: 'boolean', default: true },
476+
nickname: { type: 'string', default: 'Guest' },
477+
age: { type: 'integer', minimum: 0, maximum: 150, default: 18 },
478+
color: { type: 'string', enum: ['red', 'green'], default: 'green' },
479+
untitledSingleSelectEnum: {
480+
type: 'string',
481+
title: 'Untitled Single Select Enum',
482+
description: 'Choose your favorite color',
483+
enum: ['red', 'green', 'blue'],
484+
default: 'green'
485+
},
486+
untitledMultipleSelectEnum: {
487+
type: 'array',
488+
title: 'Untitled Multiple Select Enum',
489+
description: 'Choose your favorite colors',
490+
minItems: 1,
491+
maxItems: 3,
492+
items: { type: 'string', enum: ['red', 'green', 'blue'] },
493+
default: ['green', 'blue']
494+
},
495+
titledSingleSelectEnum: {
496+
type: 'string',
497+
title: 'Single Select Enum',
498+
description: 'Choose your favorite color',
499+
oneOf: [
500+
{ const: 'red', title: 'Red' },
501+
{ const: 'green', title: 'Green' },
502+
{ const: 'blue', title: 'Blue' }
503+
],
504+
default: 'green'
505+
},
506+
titledMultipleSelectEnum: {
507+
type: 'array',
508+
title: 'Multiple Select Enum',
509+
description: 'Choose your favorite colors',
510+
minItems: 1,
511+
maxItems: 3,
512+
items: {
513+
anyOf: [
514+
{ const: 'red', title: 'Red' },
515+
{ const: 'green', title: 'Green' },
516+
{ const: 'blue', title: 'Blue' }
517+
]
518+
},
519+
default: ['green', 'blue']
520+
},
521+
optionalWithADefault: { type: 'string', default: 'default value' }
522+
},
523+
required: [
524+
'subscribe',
525+
'nickname',
526+
'age',
527+
'color',
528+
'titledSingleSelectEnum',
529+
'titledMultipleSelectEnum',
530+
'untitledSingleSelectEnum',
531+
'untitledMultipleSelectEnum'
532+
]
533+
};
534+
535+
// Client returns no values; SDK should apply defaults automatically (and validate)
536+
client.setRequestHandler(ElicitRequestSchema, request => {
537+
expect(request.params.requestedSchema).toEqual(testSchemaProperties);
538+
return {
539+
action: 'accept',
540+
content: {}
541+
};
542+
});
543+
544+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
545+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
546+
547+
const result = await server.elicitInput({
548+
message: 'Provide your preferences',
549+
requestedSchema: testSchemaProperties
550+
});
551+
552+
expect(result).toEqual({
553+
action: 'accept',
554+
content: {
555+
subscribe: true,
556+
nickname: 'Guest',
557+
age: 18,
558+
color: 'green',
559+
untitledSingleSelectEnum: 'green',
560+
untitledMultipleSelectEnum: ['green', 'blue'],
561+
titledSingleSelectEnum: 'green',
562+
titledMultipleSelectEnum: ['green', 'blue'],
563+
optionalWithADefault: 'default value'
564+
}
565+
});
566+
});
567+
452568
test(`${validatorName}: should reject invalid email format`, async () => {
453569
client.setRequestHandler(ElicitRequestSchema, _request => ({
454570
action: 'accept',

src/spec.types.test.ts

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ type MakeUnknownsNotOptional<T> =
6262
}
6363
: T;
6464

65+
// Targeted fix: in spec, treat ClientCapabilities.elicitation?: object as Record<string, unknown>
66+
type FixSpecClientCapabilities<T> = T extends { elicitation?: object }
67+
? Omit<T, 'elicitation'> & { elicitation?: Record<string, unknown> }
68+
: T;
69+
70+
type FixSpecInitializeRequestParams<T> = T extends { capabilities: infer C }
71+
? Omit<T, 'capabilities'> & { capabilities: FixSpecClientCapabilities<C> }
72+
: T;
73+
74+
type FixSpecInitializeRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;
75+
76+
type FixSpecClientRequest<T> = T extends { params: infer P } ? Omit<T, 'params'> & { params: FixSpecInitializeRequestParams<P> } : T;
77+
6578
const sdkTypeChecks = {
6679
RequestParams: (sdk: SDKTypes.RequestParams, spec: SpecTypes.RequestParams) => {
6780
sdk = spec;
@@ -75,7 +88,10 @@ const sdkTypeChecks = {
7588
sdk = spec;
7689
spec = sdk;
7790
},
78-
InitializeRequestParams: (sdk: SDKTypes.InitializeRequestParams, spec: SpecTypes.InitializeRequestParams) => {
91+
InitializeRequestParams: (
92+
sdk: SDKTypes.InitializeRequestParams,
93+
spec: FixSpecInitializeRequestParams<SpecTypes.InitializeRequestParams>
94+
) => {
7995
sdk = spec;
8096
spec = sdk;
8197
},
@@ -508,23 +524,29 @@ const sdkTypeChecks = {
508524
sdk = spec;
509525
spec = sdk;
510526
},
511-
InitializeRequest: (sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>, spec: SpecTypes.InitializeRequest) => {
527+
InitializeRequest: (
528+
sdk: WithJSONRPCRequest<SDKTypes.InitializeRequest>,
529+
spec: FixSpecInitializeRequest<SpecTypes.InitializeRequest>
530+
) => {
512531
sdk = spec;
513532
spec = sdk;
514533
},
515534
InitializeResult: (sdk: SDKTypes.InitializeResult, spec: SpecTypes.InitializeResult) => {
516535
sdk = spec;
517536
spec = sdk;
518537
},
519-
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: SpecTypes.ClientCapabilities) => {
538+
ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: FixSpecClientCapabilities<SpecTypes.ClientCapabilities>) => {
520539
sdk = spec;
521540
spec = sdk;
522541
},
523542
ServerCapabilities: (sdk: SDKTypes.ServerCapabilities, spec: SpecTypes.ServerCapabilities) => {
524543
sdk = spec;
525544
spec = sdk;
526545
},
527-
ClientRequest: (sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>, spec: SpecTypes.ClientRequest) => {
546+
ClientRequest: (
547+
sdk: RemovePassthrough<WithJSONRPCRequest<SDKTypes.ClientRequest>>,
548+
spec: FixSpecClientRequest<SpecTypes.ClientRequest>
549+
) => {
528550
sdk = spec;
529551
spec = sdk;
530552
},

0 commit comments

Comments
 (0)