Skip to content

feat: Add Tool.outputSchema and CallToolResult.structuredContent #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package io.modelcontextprotocol.server;

import java.util.HashMap;
import java.util.List;

import io.modelcontextprotocol.spec.McpError;
Expand All @@ -17,6 +18,7 @@
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import io.modelcontextprotocol.spec.McpServerTransportProvider;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -125,7 +127,7 @@ void testAddTool() {

@Test
void testAddDuplicateTool() {
Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema);
Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema, emptyJsonSchema);

var mcpSyncServer = McpServer.sync(createMcpTransportProvider())
.serverInfo("test-server", "1.0.0")
Expand All @@ -134,7 +136,7 @@ void testAddDuplicateTool() {
.build();

assertThatThrownBy(() -> mcpSyncServer.addTool(new McpServerFeatures.SyncToolSpecification(duplicateTool,
(exchange, args) -> new CallToolResult(List.of(), false))))
(exchange, args) -> new CallToolResult(List.of(), false, new HashMap<String, Object>()))))
.isInstanceOf(McpError.class)
.hasMessage("Tool with name '" + TEST_TOOL_NAME + "' already exists");

Expand Down
7 changes: 7 additions & 0 deletions mcp/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@
<scope>test</scope>
</dependency>

<!-- Json validator dependency -->
<dependency>
<groupId>com.networknt</groupId>
<artifactId>json-schema-validator</artifactId>
<version>1.5.7</version>
</dependency>
Comment on lines +205 to +210
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth getting extra clarification from maintainers on if we want schema validation as part of this or a separate effort, along the lines of the discussion in #271.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed some of the issues offline with @LucaButBoring. I wanted to get clarity around the scope of the work on the issue that I had created. There are a few questions that we might want to answer regarding the output validation:

  1. Do we want validation to be part of this PR or do we want to separate it out and think more about how we can cache/validate the results?
  2. Should cache be in the client class or should we construct a separate cache class?
  3. Should there be refresh intervals for cache invalidation? How else can we deal with it and should it be configurable by the client?
  4. Is a Map the best way to represent cache or should we opt for an in-memory cache implementation like Guava Cache?

Would like inputs from the maintainers to decide on the best way to proceed.


</dependencies>


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

import com.fasterxml.jackson.core.type.TypeReference;

import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
Expand All @@ -27,6 +29,7 @@
import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
Expand All @@ -35,8 +38,11 @@
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.modelcontextprotocol.server.McpServerFeatures;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
Expand Down Expand Up @@ -118,6 +124,11 @@ public class McpAsyncClient {
*/
private final McpSchema.Implementation clientInfo;

/**
* Cached tool output schemas.
*/
private final ConcurrentHashMap<String, Optional<JsonSchema>> toolsOutputSchemaCache;

/**
* Roots define the boundaries of where servers can operate within the filesystem,
* allowing them to understand which directories and files they have access to.
Expand Down Expand Up @@ -181,6 +192,7 @@ public class McpAsyncClient {
this.transport = transport;
this.roots = new ConcurrentHashMap<>(features.roots());
this.initializationTimeout = initializationTimeout;
this.toolsOutputSchemaCache = new ConcurrentHashMap<>();

// Request Handlers
Map<String, RequestHandler<?>> requestHandlers = new HashMap<>();
Expand Down Expand Up @@ -331,6 +343,14 @@ public McpSchema.Implementation getClientInfo() {
return this.clientInfo;
}

/**
* Get the cached tool output schemas.
* @return The cached tool output schemas
*/
public ConcurrentHashMap<String, Optional<JsonSchema>> getToolsOutputSchemaCache() {
return this.toolsOutputSchemaCache;
}

/**
* Closes the client connection immediately.
*/
Expand Down Expand Up @@ -650,8 +670,13 @@ public Mono<McpSchema.CallToolResult> callTool(McpSchema.CallToolRequest callToo
if (init.get().capabilities().tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
// Refresh tool output schema cache, if necessary, prior to making tool call
Mono<Void> refreshCacheMono = Mono.empty();
if (!this.toolsOutputSchemaCache.containsKey(callToolRequest.name())) {
refreshCacheMono = refreshToolOutputSchemaCache();
}
return refreshCacheMono.then(init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF));
});
}

Expand All @@ -675,7 +700,33 @@ public Mono<McpSchema.ListToolsResult> listTools(String cursor) {
}
return init.mcpSession()
.sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor),
LIST_TOOLS_RESULT_TYPE_REF);
LIST_TOOLS_RESULT_TYPE_REF)
.doOnNext(result -> {
// Cache tools output schema
if (result.tools() != null) {
// Cache tools output schema
result.tools()
.forEach(tool -> this.toolsOutputSchemaCache.put(tool.name(),
Optional.ofNullable(tool.outputSchema())));
}
});
});
}

/**
* Refreshes the tool output schema cache by fetching all tools from the server.
* @return A Mono that completes when all tool output schemas have been cached
*/
private Mono<Void> refreshToolOutputSchemaCache() {
return this.withSession("refreshing tool output schema cache", init -> {

// Use expand operator to handle pagination in a reactive way
return this.listTools(null).expand(result -> {
if (result.nextCursor() != null) {
return this.listTools(result.nextCursor());
}
return Mono.empty();
}).then();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@
package io.modelcontextprotocol.client;

import java.time.Duration;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.networknt.schema.JsonSchema;
import com.networknt.schema.JsonSchemaFactory;
import com.networknt.schema.SpecVersion;
import com.networknt.schema.ValidationMessage;

import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
import io.modelcontextprotocol.spec.McpSchema.GetPromptResult;
import io.modelcontextprotocol.spec.McpSchema.ListPromptsResult;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A synchronous client implementation for the Model Context Protocol (MCP) that wraps an
Expand Down Expand Up @@ -62,14 +76,28 @@ public class McpSyncClient implements AutoCloseable {

private final McpAsyncClient delegate;

/** JSON object mapper for message serialization/deserialization */
protected ObjectMapper objectMapper;

/**
* Create a new McpSyncClient with the given delegate.
* @param delegate the asynchronous kernel on top of which this synchronous client
* provides a blocking API.
*/
McpSyncClient(McpAsyncClient delegate) {
this(delegate, new ObjectMapper());
}

/**
* Create a new McpSyncClient with the given delegate.
* @param delegate the asynchronous kernel on top of which this synchronous client
* provides a blocking API.
* @param objectMapper the object mapper for JSON serialization/deserialization
*/
McpSyncClient(McpAsyncClient delegate, ObjectMapper objectMapper) {
Assert.notNull(delegate, "The delegate can not be null");
this.delegate = delegate;
this.objectMapper = objectMapper;
}

/**
Expand Down Expand Up @@ -206,7 +234,8 @@ public Object ping() {
/**
* Calls a tool provided by the server. Tools enable servers to expose executable
* functionality that can interact with external systems, perform computations, and
* take actions in the real world.
* take actions in the real world. If tool contains an output schema, validates the
* tool result structured content against the output schema.
* @param callToolRequest The request containing: - name: The name of the tool to call
* (must match a tool name from tools/list) - arguments: Arguments that conform to the
* tool's input schema
Expand All @@ -215,7 +244,54 @@ public Object ping() {
* Boolean indicating if the execution failed (true) or succeeded (false/absent)
*/
public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) {
return this.delegate.callTool(callToolRequest).block();
McpSchema.CallToolResult result = this.delegate.callTool(callToolRequest).block();
ConcurrentHashMap<String, Optional<McpSchema.JsonSchema>> toolsOutputSchemaCache = this.delegate
.getToolsOutputSchemaCache();
// Should not be triggered but added for completeness
if (!toolsOutputSchemaCache.containsKey(callToolRequest.name())) {
throw new McpError("Tool with name '" + callToolRequest.name() + "' not found");
}
Optional<McpSchema.JsonSchema> optOutputSchema = toolsOutputSchemaCache.get(callToolRequest.name());
if (result != null && optOutputSchema != null && optOutputSchema.isPresent()) {
if (result.structuredContent() == null) {
throw new McpError("CallToolResult validation failed: structuredContent is null and "
+ "does not match tool outputSchema.");
}
McpSchema.JsonSchema outputSchema = optOutputSchema.get();

try {
// Convert outputSchema to string
String outputSchemaString = this.objectMapper.writeValueAsString(outputSchema);

// Create JsonSchema validator
ObjectNode schemaNode = (ObjectNode) this.objectMapper.readTree(outputSchemaString);
// Set additional properties to false if not specified in output schema
if (!schemaNode.has("additionalProperties")) {
schemaNode.put("additionalProperties", false);
}
JsonSchema schema = JsonSchemaFactory.getInstance(SpecVersion.VersionFlag.V202012)
.getSchema(schemaNode);

// Convert structured content in reult to JsonNode
JsonNode jsonNode = this.objectMapper.valueToTree(result.structuredContent());

// Validate outputSchema against structuredContent
Set<ValidationMessage> validationResult = schema.validate(jsonNode);

// Check if validation passed
if (!validationResult.isEmpty()) {
// Handle validation errors
throw new McpError(
"CallToolResult validation failed: structuredContent does not match tool outputSchema.");
}
}
catch (JsonProcessingException e) {
// Log warning if output schema can't be parsed to prevent erroring out
// for successful call tool request
logger.warn("Failed to validate CallToolResult: Error parsing tool outputSchema: {}", e);
}
}
return result;
}

/**
Expand Down
Loading