Skip to content

Commit 4f93ce8

Browse files
authored
feat: Propagate Context to eager connect via McpClientSession (#339)
* feat: Propagate Context to eager connect via McpClientSession In order to allow the initial connection to have contextual information in the reactive chain, the McpClientSession should be able to transform the McpClientTransport#connect result, e.g. to attach Context items. This change introduces a new constructor for sessions that makes it possible. Signed-off-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com> --------- Signed-off-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent 9ebff0c commit 4f93ce8

File tree

3 files changed

+110
-11
lines changed

3 files changed

+110
-11
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.concurrent.ConcurrentHashMap;
1212
import java.util.concurrent.atomic.AtomicReference;
1313
import java.util.function.Function;
14-
import java.util.function.Supplier;
1514

1615
import org.slf4j.Logger;
1716
import org.slf4j.LoggerFactory;
@@ -42,6 +41,7 @@
4241
import reactor.core.publisher.Flux;
4342
import reactor.core.publisher.Mono;
4443
import reactor.core.publisher.Sinks;
44+
import reactor.util.context.ContextView;
4545

4646
/**
4747
* The Model Context Protocol (MCP) client implementation that provides asynchronous
@@ -161,7 +161,7 @@ public class McpAsyncClient {
161161
* The MCP session supplier that manages bidirectional JSON-RPC communication between
162162
* clients and servers.
163163
*/
164-
private final Supplier<McpClientSession> sessionSupplier;
164+
private final Function<ContextView, McpClientSession> sessionSupplier;
165165

166166
/**
167167
* Create a new McpAsyncClient with the given transport and session request-response
@@ -268,9 +268,8 @@ public class McpAsyncClient {
268268
asyncLoggingNotificationHandler(loggingConsumersFinal));
269269

270270
this.transport.setExceptionHandler(this::handleException);
271-
this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers,
272-
notificationHandlers);
273-
271+
this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
272+
notificationHandlers, con -> con.contextWrite(ctx));
274273
}
275274

276275
private void handleException(Throwable t) {
@@ -401,9 +400,8 @@ public Mono<McpSchema.InitializeResult> initialize() {
401400
return withSession("by explicit API call", init -> Mono.just(init.get()));
402401
}
403402

404-
private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization) {
405-
406-
initialization.setMcpClientSession(this.sessionSupplier.get());
403+
private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization, ContextView ctx) {
404+
initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
407405

408406
McpClientSession mcpClientSession = initialization.mcpSession();
409407

@@ -493,14 +491,14 @@ Mono<Void> closeGracefully() {
493491
* @return A Mono that completes with the result of the operation
494492
*/
495493
private <T> Mono<T> withSession(String actionName, Function<Initialization, Mono<T>> operation) {
496-
return Mono.defer(() -> {
494+
return Mono.deferContextual(ctx -> {
497495
Initialization newInit = Initialization.create();
498496
Initialization previous = this.initializationRef.compareAndExchange(null, newInit);
499497

500498
boolean needsToInitialize = previous == null;
501499
logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
502500

503-
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit)
501+
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
504502
: previous.await();
505503

506504
return initializationJob.map(initializeResult -> this.initializationRef.get())

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import com.fasterxml.jackson.core.type.TypeReference;
88
import io.modelcontextprotocol.util.Assert;
9+
import org.reactivestreams.Publisher;
910
import org.slf4j.Logger;
1011
import org.slf4j.LoggerFactory;
1112
import reactor.core.publisher.Mono;
@@ -16,6 +17,7 @@
1617
import java.util.UUID;
1718
import java.util.concurrent.ConcurrentHashMap;
1819
import java.util.concurrent.atomic.AtomicLong;
20+
import java.util.function.Function;
1921

2022
/**
2123
* Default implementation of the MCP (Model Context Protocol) session that manages
@@ -99,9 +101,27 @@ public interface NotificationHandler {
99101
* @param transport Transport implementation for message exchange
100102
* @param requestHandlers Map of method names to request handlers
101103
* @param notificationHandlers Map of method names to notification handlers
104+
* @deprecated Use
105+
* {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)}
102106
*/
107+
@Deprecated
103108
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
104109
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
110+
this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity());
111+
}
112+
113+
/**
114+
* Creates a new McpClientSession with the specified configuration and handlers.
115+
* @param requestTimeout Duration to wait for responses
116+
* @param transport Transport implementation for message exchange
117+
* @param requestHandlers Map of method names to request handlers
118+
* @param notificationHandlers Map of method names to notification handlers
119+
* @param connectHook Hook that allows transforming the connection Publisher prior to
120+
* subscribing
121+
*/
122+
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
123+
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers,
124+
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
105125

106126
Assert.notNull(requestTimeout, "The requestTimeout can not be null");
107127
Assert.notNull(transport, "The transport can not be null");
@@ -113,7 +133,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
113133
this.requestHandlers.putAll(requestHandlers);
114134
this.notificationHandlers.putAll(notificationHandlers);
115135

116-
this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe();
136+
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
117137
}
118138

119139
private void dismissPendingResponses() {
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import com.fasterxml.jackson.core.type.TypeReference;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import io.modelcontextprotocol.spec.McpClientTransport;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import org.junit.jupiter.api.Test;
8+
import reactor.core.publisher.Mono;
9+
10+
import java.util.concurrent.atomic.AtomicReference;
11+
import java.util.function.Function;
12+
13+
import static org.assertj.core.api.Assertions.assertThatCode;
14+
15+
class McpAsyncClientTests {
16+
17+
public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server",
18+
"1.0.0");
19+
20+
public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder()
21+
.build();
22+
23+
public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult(
24+
McpSchema.LATEST_PROTOCOL_VERSION, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions");
25+
26+
private static final String CONTEXT_KEY = "context.key";
27+
28+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
29+
30+
@Test
31+
void validateContextPassedToTransportConnect() {
32+
McpClientTransport transport = new McpClientTransport() {
33+
Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler;
34+
35+
final AtomicReference<String> contextValue = new AtomicReference<>();
36+
37+
@Override
38+
public Mono<Void> connect(
39+
Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
40+
return Mono.deferContextual(ctx -> {
41+
this.handler = handler;
42+
if (ctx.hasKey(CONTEXT_KEY)) {
43+
this.contextValue.set(ctx.get(CONTEXT_KEY));
44+
}
45+
return Mono.empty();
46+
});
47+
}
48+
49+
@Override
50+
public Mono<Void> closeGracefully() {
51+
return Mono.empty();
52+
}
53+
54+
@Override
55+
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
56+
if (!"hello".equals(this.contextValue.get())) {
57+
return Mono.error(new RuntimeException("Context value not propagated via #connect method"));
58+
}
59+
// We're only interested in handling the init request to provide an init
60+
// response
61+
if (!(message instanceof McpSchema.JSONRPCRequest)) {
62+
return Mono.empty();
63+
}
64+
McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION,
65+
((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null);
66+
return handler.apply(Mono.just(initResponse)).then();
67+
}
68+
69+
@Override
70+
public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
71+
return OBJECT_MAPPER.convertValue(data, typeRef);
72+
}
73+
};
74+
75+
assertThatCode(() -> {
76+
McpAsyncClient client = McpClient.async(transport).build();
77+
client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block();
78+
}).doesNotThrowAnyException();
79+
}
80+
81+
}

0 commit comments

Comments
 (0)