Skip to content

feat: Propagate Context to eager connect via McpClientSession #339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -42,6 +41,7 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;
import reactor.util.context.ContextView;

/**
* The Model Context Protocol (MCP) client implementation that provides asynchronous
Expand Down Expand Up @@ -161,7 +161,7 @@ public class McpAsyncClient {
* The MCP session supplier that manages bidirectional JSON-RPC communication between
* clients and servers.
*/
private final Supplier<McpClientSession> sessionSupplier;
private final Function<ContextView, McpClientSession> sessionSupplier;

/**
* Create a new McpAsyncClient with the given transport and session request-response
Expand Down Expand Up @@ -268,9 +268,8 @@ public class McpAsyncClient {
asyncLoggingNotificationHandler(loggingConsumersFinal));

this.transport.setExceptionHandler(this::handleException);
this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers,
notificationHandlers);

this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
notificationHandlers, con -> con.contextWrite(ctx));
}

private void handleException(Throwable t) {
Expand Down Expand Up @@ -401,9 +400,8 @@ public Mono<McpSchema.InitializeResult> initialize() {
return withSession("by explicit API call", init -> Mono.just(init.get()));
}

private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization) {

initialization.setMcpClientSession(this.sessionSupplier.get());
private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization, ContextView ctx) {
initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));

McpClientSession mcpClientSession = initialization.mcpSession();

Expand Down Expand Up @@ -493,14 +491,14 @@ Mono<Void> closeGracefully() {
* @return A Mono that completes with the result of the operation
*/
private <T> Mono<T> withSession(String actionName, Function<Initialization, Mono<T>> operation) {
return Mono.defer(() -> {
return Mono.deferContextual(ctx -> {
Initialization newInit = Initialization.create();
Initialization previous = this.initializationRef.compareAndExchange(null, newInit);

boolean needsToInitialize = previous == null;
logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");

Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit)
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
: previous.await();

return initializationJob.map(initializeResult -> this.initializationRef.get())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.util.Assert;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand All @@ -16,6 +17,7 @@
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;

/**
* Default implementation of the MCP (Model Context Protocol) session that manages
Expand Down Expand Up @@ -99,9 +101,27 @@ public interface NotificationHandler {
* @param transport Transport implementation for message exchange
* @param requestHandlers Map of method names to request handlers
* @param notificationHandlers Map of method names to notification handlers
* @deprecated Use
* {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)}
*/
@Deprecated
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity());
}

/**
* Creates a new McpClientSession with the specified configuration and handlers.
* @param requestTimeout Duration to wait for responses
* @param transport Transport implementation for message exchange
* @param requestHandlers Map of method names to request handlers
* @param notificationHandlers Map of method names to notification handlers
* @param connectHook Hook that allows transforming the connection Publisher prior to
* subscribing
*/
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers,
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {

Assert.notNull(requestTimeout, "The requestTimeout can not be null");
Assert.notNull(transport, "The transport can not be null");
Expand All @@ -113,7 +133,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
this.requestHandlers.putAll(requestHandlers);
this.notificationHandlers.putAll(notificationHandlers);

this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe();
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
}

private void dismissPendingResponses() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.modelcontextprotocol.client;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThatCode;

class McpAsyncClientTests {

public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server",
"1.0.0");

public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder()
.build();

public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult(
McpSchema.LATEST_PROTOCOL_VERSION, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions");

private static final String CONTEXT_KEY = "context.key";

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

@Test
void validateContextPassedToTransportConnect() {
McpClientTransport transport = new McpClientTransport() {
Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler;

final AtomicReference<String> contextValue = new AtomicReference<>();

@Override
public Mono<Void> connect(
Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
return Mono.deferContextual(ctx -> {
this.handler = handler;
if (ctx.hasKey(CONTEXT_KEY)) {
this.contextValue.set(ctx.get(CONTEXT_KEY));
}
return Mono.empty();
});
}

@Override
public Mono<Void> closeGracefully() {
return Mono.empty();
}

@Override
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
if (!"hello".equals(this.contextValue.get())) {
return Mono.error(new RuntimeException("Context value not propagated via #connect method"));
}
// We're only interested in handling the init request to provide an init
// response
if (!(message instanceof McpSchema.JSONRPCRequest)) {
return Mono.empty();
}
McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION,
((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null);
return handler.apply(Mono.just(initResponse)).then();
}

@Override
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
return OBJECT_MAPPER.convertValue(data, typeRef);
}
};

assertThatCode(() -> {
McpAsyncClient client = McpClient.async(transport).build();
client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block();
}).doesNotThrowAnyException();
}

}