Skip to content

Commit 98450f7

Browse files
committed
implement completion request and add integration test
Signed-off-by: jitokim <pigberger70@gmail.com>
1 parent 5d036f0 commit 98450f7

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Map;
1010
import java.util.concurrent.ConcurrentHashMap;
1111
import java.util.concurrent.atomic.AtomicReference;
12+
import java.util.function.BiFunction;
1213
import java.util.function.Function;
1314
import java.util.stream.Collectors;
1415

@@ -18,21 +19,16 @@
1819
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
1920
import io.modelcontextprotocol.server.McpServer;
2021
import io.modelcontextprotocol.server.McpServerFeatures;
22+
import io.modelcontextprotocol.server.McpSyncServerExchange;
2123
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2224
import io.modelcontextprotocol.spec.McpError;
2325
import io.modelcontextprotocol.spec.McpSchema;
24-
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
25-
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
26-
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
27-
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
28-
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
29-
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
30-
import io.modelcontextprotocol.spec.McpSchema.Role;
31-
import io.modelcontextprotocol.spec.McpSchema.Root;
32-
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
33-
import io.modelcontextprotocol.spec.McpSchema.Tool;
26+
import io.modelcontextprotocol.spec.McpSchema.*;
27+
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities;
28+
import io.modelcontextprotocol.spec.McpServerSession;
3429
import org.junit.jupiter.api.AfterEach;
3530
import org.junit.jupiter.api.BeforeEach;
31+
import org.junit.jupiter.api.Test;
3632
import org.junit.jupiter.params.ParameterizedTest;
3733
import org.junit.jupiter.params.provider.ValueSource;
3834
import reactor.core.publisher.Mono;
@@ -620,4 +616,48 @@ void testLoggingNotification(String clientType) {
620616
mcpServer.close();
621617
}
622618

623-
}
619+
@ParameterizedTest(name = "{0} : Completion call")
620+
@ValueSource(strings = { "httpclient", "webflux" })
621+
void testCompletionShouldReturnExpectedSuggestions(String clientType) {
622+
var clientBuilder = clientBuilders.get(clientType);
623+
624+
var expectedValues = List.of("python", "pytorch", "pyside");
625+
var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total
626+
true // hasMore
627+
));
628+
629+
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (mcpSyncServerExchange,
630+
request) -> {
631+
assertThat(request.argument().name()).isEqualTo("language");
632+
assertThat(request.argument().value()).isEqualTo("py");
633+
assertThat(request.ref().type()).isEqualTo("ref/prompt");
634+
return completionResponse;
635+
};
636+
637+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
638+
.capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build())
639+
.prompts(new McpServerFeatures.SyncPromptSpecification(
640+
new Prompt("code_review", "this is code review prompt", List.of()),
641+
(mcpSyncServerExchange, getPromptRequest) -> null))
642+
.completions(new McpServerFeatures.SyncCompletionSpecification(
643+
new McpServerFeatures.CompletionRefKey("ref/prompt", "code_review"), completionHandler))
644+
.build();
645+
646+
try (var mcpClient = clientBuilder.build()) {
647+
648+
InitializeResult initResult = mcpClient.initialize();
649+
assertThat(initResult).isNotNull();
650+
651+
CompleteRequest request = new CompleteRequest(
652+
new CompleteRequest.PromptReference("ref/prompt", "code_review"),
653+
new CompleteRequest.CompleteArgument("language", "py"));
654+
655+
CompleteResult result = mcpClient.completeCompletion(request);
656+
657+
assertThat(result).isNotNull();
658+
}
659+
660+
mcpServer.close();
661+
}
662+
663+
}

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,4 +801,15 @@ void setProtocolVersions(List<String> protocolVersions) {
801801
this.protocolVersions = protocolVersions;
802802
}
803803

804+
// --------------------------
805+
// Completions
806+
// --------------------------
807+
private static final TypeReference<McpSchema.CompleteResult> COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() {
808+
};
809+
810+
public Mono<McpSchema.CompleteResult> completeCompletion(McpSchema.CompleteRequest completeRequest) {
811+
return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession
812+
.sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF));
813+
}
814+
804815
}

mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,4 +325,8 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
325325
this.delegate.setLoggingLevel(loggingLevel).block();
326326
}
327327

328+
public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) {
329+
return this.delegate.completeCompletion(completeRequest).block();
330+
}
331+
328332
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ private McpServerSession.RequestHandler<Object> setLoggerRequestHandler() {
716716

717717
private McpServerSession.RequestHandler<McpSchema.CompleteResult> completionCompleteRequestHandler() {
718718
return (exchange, params) -> {
719-
McpSchema.CompleteRequest request = objectMapper.convertValue(params, McpSchema.CompleteRequest.class);
719+
McpSchema.CompleteRequest request = parseCompletionParams(params);
720720

721721
if (request.ref() == null) {
722722
return Mono.error(new McpError("ref must not be null"));
@@ -755,6 +755,30 @@ private McpServerSession.RequestHandler<McpSchema.CompleteResult> completionComp
755755
};
756756
}
757757

758+
@SuppressWarnings("unchecked")
759+
private McpSchema.CompleteRequest parseCompletionParams(Object object) {
760+
Map<String, Object> params = (Map<String, Object>) object;
761+
Map<String, Object> refMap = (Map<String, Object>) params.get("ref");
762+
Map<String, Object> argMap = (Map<String, Object>) params.get("argument");
763+
764+
String refType = (String) refMap.get("type");
765+
766+
McpSchema.CompleteRequest.PromptOrResourceReference ref = switch (refType) {
767+
case "ref/prompt" ->
768+
new McpSchema.CompleteRequest.PromptReference(refType, (String) refMap.get("name"));
769+
case "ref/resource" ->
770+
new McpSchema.CompleteRequest.ResourceReference(refType, (String) refMap.get("uri"));
771+
default -> throw new IllegalArgumentException("Invalid ref type: " + refType);
772+
};
773+
774+
String argName = (String) argMap.get("name");
775+
String argValue = (String) argMap.get("value");
776+
McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(
777+
argName, argValue);
778+
779+
return new McpSchema.CompleteRequest(ref, argument);
780+
}
781+
758782
// ---------------------------------------
759783
// Sampling
760784
// ---------------------------------------

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
238238
});
239239
}).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> {
240240
if (jsonRpcResponse.error() != null) {
241+
logger.error("Error handling request: {}", jsonRpcResponse.error());
241242
sink.error(new McpError(jsonRpcResponse.error()));
242243
}
243244
else {

0 commit comments

Comments
 (0)