Skip to content

Commit 09cbe0a

Browse files
committed
implement completion request and add integration test
Signed-off-by: jitokim <pigberger70@gmail.com>
1 parent 5d036f0 commit 09cbe0a

File tree

5 files changed

+92
-12
lines changed

5 files changed

+92
-12
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Map;
1010
import java.util.concurrent.ConcurrentHashMap;
1111
import java.util.concurrent.atomic.AtomicReference;
12+
import java.util.function.BiFunction;
1213
import java.util.function.Function;
1314
import java.util.stream.Collectors;
1415

@@ -18,19 +19,12 @@
1819
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
1920
import io.modelcontextprotocol.server.McpServer;
2021
import io.modelcontextprotocol.server.McpServerFeatures;
22+
import io.modelcontextprotocol.server.McpSyncServerExchange;
2123
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2224
import io.modelcontextprotocol.spec.McpError;
2325
import io.modelcontextprotocol.spec.McpSchema;
24-
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
25-
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
26-
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
27-
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
28-
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
29-
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
30-
import io.modelcontextprotocol.spec.McpSchema.Role;
31-
import io.modelcontextprotocol.spec.McpSchema.Root;
32-
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
33-
import io.modelcontextprotocol.spec.McpSchema.Tool;
26+
import io.modelcontextprotocol.spec.McpSchema.*;
27+
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities.CompletionCapabilities;
3428
import org.junit.jupiter.api.AfterEach;
3529
import org.junit.jupiter.api.BeforeEach;
3630
import org.junit.jupiter.params.ParameterizedTest;
@@ -620,4 +614,48 @@ void testLoggingNotification(String clientType) {
620614
mcpServer.close();
621615
}
622616

623-
}
617+
@ParameterizedTest(name = "{0} : Completion call")
618+
@ValueSource(strings = { "httpclient", "webflux" })
619+
void testCompletionShouldReturnExpectedSuggestions(String clientType) {
620+
var clientBuilder = clientBuilders.get(clientType);
621+
622+
var expectedValues = List.of("python", "pytorch", "pyside");
623+
var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total
624+
true // hasMore
625+
));
626+
627+
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (mcpSyncServerExchange,
628+
request) -> {
629+
assertThat(request.argument().name()).isEqualTo("language");
630+
assertThat(request.argument().value()).isEqualTo("py");
631+
assertThat(request.ref().type()).isEqualTo("ref/prompt");
632+
return completionResponse;
633+
};
634+
635+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
636+
.capabilities(ServerCapabilities.builder().completions(new CompletionCapabilities()).build())
637+
.prompts(new McpServerFeatures.SyncPromptSpecification(
638+
new Prompt("code_review", "this is code review prompt", List.of()),
639+
(mcpSyncServerExchange, getPromptRequest) -> null))
640+
.completions(new McpServerFeatures.SyncCompletionSpecification(
641+
new McpServerFeatures.CompletionRefKey("ref/prompt", "code_review"), completionHandler))
642+
.build();
643+
644+
try (var mcpClient = clientBuilder.build()) {
645+
646+
InitializeResult initResult = mcpClient.initialize();
647+
assertThat(initResult).isNotNull();
648+
649+
CompleteRequest request = new CompleteRequest(
650+
new CompleteRequest.PromptReference("ref/prompt", "code_review"),
651+
new CompleteRequest.CompleteArgument("language", "py"));
652+
653+
CompleteResult result = mcpClient.completeCompletion(request);
654+
655+
assertThat(result).isNotNull();
656+
}
657+
658+
mcpServer.close();
659+
}
660+
661+
}

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
*
7272
* @author Dariusz Jędrzejczyk
7373
* @author Christian Tzolov
74+
* @author Jihoon Kim
7475
* @see McpClient
7576
* @see McpSchema
7677
* @see McpClientSession
@@ -801,4 +802,15 @@ void setProtocolVersions(List<String> protocolVersions) {
801802
this.protocolVersions = protocolVersions;
802803
}
803804

805+
// --------------------------
806+
// Completions
807+
// --------------------------
808+
private static final TypeReference<McpSchema.CompleteResult> COMPLETION_COMPLETE_RESULT_TYPE_REF = new TypeReference<>() {
809+
};
810+
811+
public Mono<McpSchema.CompleteResult> completeCompletion(McpSchema.CompleteRequest completeRequest) {
812+
return this.withInitializationCheck("complete completions", initializedResult -> this.mcpSession
813+
.sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF));
814+
}
815+
804816
}

mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
*
4747
* @author Dariusz Jędrzejczyk
4848
* @author Christian Tzolov
49+
* @author Jihoon Kim
4950
* @see McpClient
5051
* @see McpAsyncClient
5152
* @see McpSchema
@@ -325,4 +326,8 @@ public void setLoggingLevel(McpSchema.LoggingLevel loggingLevel) {
325326
this.delegate.setLoggingLevel(loggingLevel).block();
326327
}
327328

329+
public McpSchema.CompleteResult completeCompletion(McpSchema.CompleteRequest completeRequest) {
330+
return this.delegate.completeCompletion(completeRequest).block();
331+
}
332+
328333
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ private McpServerSession.RequestHandler<Object> setLoggerRequestHandler() {
716716

717717
private McpServerSession.RequestHandler<McpSchema.CompleteResult> completionCompleteRequestHandler() {
718718
return (exchange, params) -> {
719-
McpSchema.CompleteRequest request = objectMapper.convertValue(params, McpSchema.CompleteRequest.class);
719+
McpSchema.CompleteRequest request = parseCompletionParams(params);
720720

721721
if (request.ref() == null) {
722722
return Mono.error(new McpError("ref must not be null"));
@@ -755,6 +755,30 @@ private McpServerSession.RequestHandler<McpSchema.CompleteResult> completionComp
755755
};
756756
}
757757

758+
@SuppressWarnings("unchecked")
759+
private McpSchema.CompleteRequest parseCompletionParams(Object object) {
760+
Map<String, Object> params = (Map<String, Object>) object;
761+
Map<String, Object> refMap = (Map<String, Object>) params.get("ref");
762+
Map<String, Object> argMap = (Map<String, Object>) params.get("argument");
763+
764+
String refType = (String) refMap.get("type");
765+
766+
McpSchema.CompleteRequest.PromptOrResourceReference ref = switch (refType) {
767+
case "ref/prompt" ->
768+
new McpSchema.CompleteRequest.PromptReference(refType, (String) refMap.get("name"));
769+
case "ref/resource" ->
770+
new McpSchema.CompleteRequest.ResourceReference(refType, (String) refMap.get("uri"));
771+
default -> throw new IllegalArgumentException("Invalid ref type: " + refType);
772+
};
773+
774+
String argName = (String) argMap.get("name");
775+
String argValue = (String) argMap.get("value");
776+
McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(
777+
argName, argValue);
778+
779+
return new McpSchema.CompleteRequest(ref, argument);
780+
}
781+
758782
// ---------------------------------------
759783
// Sampling
760784
// ---------------------------------------

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
238238
});
239239
}).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> {
240240
if (jsonRpcResponse.error() != null) {
241+
logger.error("Error handling request: {}", jsonRpcResponse.error());
241242
sink.error(new McpError(jsonRpcResponse.error()));
242243
}
243244
else {

0 commit comments

Comments
 (0)