Skip to content

Commit 8e11fa8

Browse files
pantanurag555Anurag PanttzolovKehrlannchemicL
committed
feat(client): add client tool output schema validation and caching (#302)
- Add JSON schema validation for tool call results against output schemas - Implement automatic tool output schema caching during initialization - Add `enableCallToolSchemaCaching` configuration option to enable/disable schema caching - Add `JsonSchemaValidator` integration to McpClient builder APIs - Introduce post-initialization hook mechanism for performing operations after successful client initialization - Cache tool output schemas during `listTools` operations when caching is enabled - Validate tool results against cached schemas in `callTool` operations - Return error CallToolResult when validation fails - Add test coverage - Convert validateToolResult from Mono to synchronous method - Throw IllegalArgumentException on validation errors Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com> Co-authored-by: Anurag Pant <pantanurag555@github> Co-authored-by: Christian Tzolov <christian.tzolov@broadcom.com> Co-authored-by: Daniel Garnier-Moiroux <git@garnier.wf> Co-authored-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent 3f7578b commit 8e11fa8

File tree

10 files changed

+708
-131
lines changed

10 files changed

+708
-131
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
import java.util.concurrent.atomic.AtomicReference;
1212
import java.util.function.Function;
1313

14-
import org.slf4j.Logger;
15-
import org.slf4j.LoggerFactory;
16-
1714
import io.modelcontextprotocol.spec.McpClientSession;
1815
import io.modelcontextprotocol.spec.McpError;
1916
import io.modelcontextprotocol.spec.McpSchema;
2017
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
2118
import io.modelcontextprotocol.util.Assert;
19+
import org.slf4j.Logger;
20+
import org.slf4j.LoggerFactory;
2221
import reactor.core.publisher.Mono;
2322
import reactor.core.publisher.Sinks;
2423
import reactor.util.context.ContextView;
@@ -99,21 +98,30 @@ class LifecycleInitializer {
9998
*/
10099
private final Duration initializationTimeout;
101100

101+
/**
102+
* Post-initialization hook to perform additional operations after every successful
103+
* initialization.
104+
*/
105+
private final Function<Initialization, Mono<Void>> postInitializationHook;
106+
102107
public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
103108
List<String> protocolVersions, Duration initializationTimeout,
104-
Function<ContextView, McpClientSession> sessionSupplier) {
109+
Function<ContextView, McpClientSession> sessionSupplier,
110+
Function<Initialization, Mono<Void>> postInitializationHook) {
105111

106112
Assert.notNull(sessionSupplier, "Session supplier must not be null");
107113
Assert.notNull(clientCapabilities, "Client capabilities must not be null");
108114
Assert.notNull(clientInfo, "Client info must not be null");
109115
Assert.notEmpty(protocolVersions, "Protocol versions must not be empty");
110116
Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
117+
Assert.notNull(postInitializationHook, "Post-initialization hook must not be null");
111118

112119
this.sessionSupplier = sessionSupplier;
113120
this.clientCapabilities = clientCapabilities;
114121
this.clientInfo = clientInfo;
115122
this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions));
116123
this.initializationTimeout = initializationTimeout;
124+
this.postInitializationHook = postInitializationHook;
117125
}
118126

119127
/**
@@ -148,10 +156,6 @@ interface Initialization {
148156

149157
}
150158

151-
/**
152-
* Default implementation of the {@link Initialization} interface that manages the MCP
153-
* client initialization process.
154-
*/
155159
private static class DefaultInitialization implements Initialization {
156160

157161
/**
@@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) {
199203
this.mcpClientSession.set(mcpClientSession);
200204
}
201205

202-
/**
203-
* Returns a Mono that completes when the MCP client initialization is complete.
204-
* This allows subscribers to wait for the initialization to finish before
205-
* proceeding with further operations.
206-
* @return A Mono that emits the result of the MCP initialization process
207-
*/
208206
private Mono<McpSchema.InitializeResult> await() {
209207
return this.initSink.asMono();
210208
}
211209

212-
/**
213-
* Completes the initialization process with the given result. It caches the
214-
* result and emits it to all subscribers waiting for the initialization to
215-
* complete.
216-
* @param initializeResult The result of the MCP initialization process
217-
*/
218210
private void complete(McpSchema.InitializeResult initializeResult) {
219-
// first ensure the result is cached
220-
this.result.set(initializeResult);
221211
// inform all the subscribers waiting for the initialization
222212
this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
223213
}
224214

215+
private void cacheResult(McpSchema.InitializeResult initializeResult) {
216+
// first ensure the result is cached
217+
this.result.set(initializeResult);
218+
}
219+
225220
private void error(Throwable t) {
226221
this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
227222
}
@@ -263,7 +258,7 @@ public void handleException(Throwable t) {
263258
}
264259
// Providing an empty operation since we are only interested in triggering
265260
// the implicit initialization step.
266-
withIntitialization("re-initializing", result -> Mono.empty()).subscribe();
261+
this.withInitialization("re-initializing", result -> Mono.empty()).subscribe();
267262
}
268263
}
269264

@@ -275,16 +270,16 @@ public void handleException(Throwable t) {
275270
* @param operation The operation to execute when the client is initialized
276271
* @return A Mono that completes with the result of the operation
277272
*/
278-
public <T> Mono<T> withIntitialization(String actionName, Function<Initialization, Mono<T>> operation) {
273+
public <T> Mono<T> withInitialization(String actionName, Function<Initialization, Mono<T>> operation) {
279274
return Mono.deferContextual(ctx -> {
280275
DefaultInitialization newInit = new DefaultInitialization();
281276
DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit);
282277

283278
boolean needsToInitialize = previous == null;
284279
logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
285280

286-
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
287-
: previous.await();
281+
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize
282+
? this.doInitialize(newInit, this.postInitializationHook, ctx) : previous.await();
288283

289284
return initializationJob.map(initializeResult -> this.initializationRef.get())
290285
.timeout(this.initializationTimeout)
@@ -296,7 +291,9 @@ public <T> Mono<T> withIntitialization(String actionName, Function<Initializatio
296291
});
297292
}
298293

299-
private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization initialization, ContextView ctx) {
294+
private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization initialization,
295+
Function<Initialization, Mono<Void>> postInitOperation, ContextView ctx) {
296+
300297
initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
301298

302299
McpClientSession mcpClientSession = initialization.mcpSession();
@@ -323,6 +320,9 @@ private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization init
323320

324321
return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
325322
.thenReturn(initializeResult);
323+
}).flatMap(initializeResult -> {
324+
initialization.cacheResult(initializeResult);
325+
return postInitOperation.apply(initialization).thenReturn(initializeResult);
326326
}).doOnNext(initialization::complete).onErrorResume(ex -> {
327327
initialization.error(ex);
328328
return Mono.error(ex);

0 commit comments

Comments
 (0)