Skip to content

Commit

Permalink
Send the initialization notification
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartisk committed Feb 5, 2025
1 parent 1542582 commit 62987b4
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 16 deletions.
4 changes: 2 additions & 2 deletions docs/modules/ROOT/pages/includes/attributes.adoc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
:project-version: 0.24.0.CR1
:langchain4j-version: 1.0.0-alpha1
:examples-dir: ./../examples/
:langchain4j-version: 1.0.0-alpha2-SNAPSHOT
:examples-dir: ./../examples/
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class MockHttpMcpServer {
private volatile SseEventSink sink;
private volatile Sse sse;
private final ObjectMapper objectMapper = new ObjectMapper();
private volatile boolean initializationNotificationReceived = false;

@Inject
ScheduledExecutorService scheduledExecutorService;
Expand All @@ -111,12 +112,21 @@ public Response post(JsonNode message) {
if (method.equals("notifications/cancelled")) {
return Response.ok().build();
}
if (method.equals("notifications/initialized")) {
if (initializationNotificationReceived) {
return Response.serverError().entity("Duplicate 'notifications/initialized' message").build();
}
initializationNotificationReceived = true;
return Response.ok().build();
}
String operationId = message.get("id").asText();
if (method.equals("initialize")) {
initialize(operationId);
} else if (method.equals("tools/list")) {
ensureInitialized();
listTools(operationId);
} else if (method.equals("tools/call")) {
ensureInitialized();
if (message.get("params").get("name").asText().equals("add")) {
executeAddOperation(message, operationId);
} else if (message.get("params").get("name").asText().equals("longRunningOperation")) {
Expand All @@ -128,6 +138,13 @@ public Response post(JsonNode message) {
return Response.accepted().build();
}

// throw an exception if we haven't received the 'notifications/initialized' message yet
private void ensureInitialized() {
if (!initializationNotificationReceived) {
throw new IllegalStateException("The client has not sent the 'notifications/initialized' message yet");
}
}

private void listTools(String operationId) {
String response = TOOLS_LIST_RESPONSE.formatted(operationId);
sink.send(sse.newEventBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.fasterxml.jackson.databind.JsonNode;

import dev.langchain4j.mcp.client.protocol.CancellationNotification;
import dev.langchain4j.mcp.client.protocol.InitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpCallToolRequest;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
Expand All @@ -24,6 +25,7 @@
import dev.langchain4j.mcp.client.transport.McpTransport;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;

public class QuarkusHttpMcpTransport implements McpTransport {

Expand Down Expand Up @@ -79,37 +81,49 @@ public void start(McpOperationHandler messageHandler) {

@Override
public CompletableFuture<JsonNode> initialize(McpInitializeRequest request) {
return execute(request, request.getId());
return execute(request, request.getId()).onItem()
.transformToUni(
response -> execute(new InitializationNotification(), null).onItem().transform(ignored -> response))
.subscribeAsCompletionStage();
}

@Override
public CompletableFuture<JsonNode> listTools(McpListToolsRequest operation) {
return execute(operation, operation.getId());
return execute(operation, operation.getId()).subscribeAsCompletionStage();
}

@Override
public void cancelOperation(long operationId) {
CancellationNotification cancellationNotification = new CancellationNotification(operationId, "Timeout");
execute(cancellationNotification, null);
execute(cancellationNotification, null).subscribe().with(ignored -> {
});
}

@Override
public CompletableFuture<JsonNode> executeTool(McpCallToolRequest operation) {
return execute(operation, operation.getId());
return execute(operation, operation.getId()).subscribeAsCompletionStage();
}

private CompletableFuture<JsonNode> execute(McpClientMessage request, Long id) {
private Uni<JsonNode> execute(McpClientMessage request, Long id) {
CompletableFuture<JsonNode> future = new CompletableFuture<>();
Uni<JsonNode> uni = Uni.createFrom().completionStage(future);
if (id != null) {
operationHandler.startOperation(id, future);
}
postEndpoint.post(request).onItem().invoke(response -> {
int statusCode = response.getStatus();
if (!isExpectedStatusCode(statusCode)) {
throw new RuntimeException("Unexpected status code: " + statusCode);
}
}).subscribeAsCompletionStage();
return future;
postEndpoint.post(request)
.onFailure().invoke(future::completeExceptionally)
.onItem().invoke(response -> {
int statusCode = response.getStatus();
if (!isExpectedStatusCode(statusCode)) {
future.completeExceptionally(new RuntimeException("Unexpected status code: " + statusCode));
}
// For messages with null ID, we don't wait for a response in the SSE channel,
// so if the server accepted the request, we consider the operation done
if (id == null) {
future.complete(null);
}
}).subscribeAsCompletionStage();
return uni;
}

private boolean isExpectedStatusCode(int statusCode) {
Expand Down
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
<properties>
<quarkus.version>3.15.2</quarkus.version>
<quarkus.extension-processor.version>3.18.0</quarkus.extension-processor.version>
<langchain4j.version>1.0.0-alpha1</langchain4j.version>
<langchain4j-embeddings.version>1.0.0-alpha1</langchain4j-embeddings.version>
<langchain4j.version>1.0.0-beta1</langchain4j.version>
<langchain4j-embeddings.version>1.0.0-beta1</langchain4j-embeddings.version>
<quarkus-antora.version>1.0.2</quarkus-antora.version>
<quarkus-poi.version>2.0.4</quarkus-poi.version> <!-- we need to use this version because langchain4j uses POI 5.2.3 instead of 5.2.5 and the substitution needed is different in the two versions -->
<assertj.version>3.27.3</assertj.version>
Expand Down

0 comments on commit 62987b4

Please sign in to comment.