11
11
import java .util .concurrent .atomic .AtomicReference ;
12
12
import java .util .function .Function ;
13
13
14
- import org .slf4j .Logger ;
15
- import org .slf4j .LoggerFactory ;
16
-
17
14
import io .modelcontextprotocol .spec .McpClientSession ;
18
15
import io .modelcontextprotocol .spec .McpError ;
19
16
import io .modelcontextprotocol .spec .McpSchema ;
20
17
import io .modelcontextprotocol .spec .McpTransportSessionNotFoundException ;
21
18
import io .modelcontextprotocol .util .Assert ;
19
+ import org .slf4j .Logger ;
20
+ import org .slf4j .LoggerFactory ;
22
21
import reactor .core .publisher .Mono ;
23
22
import reactor .core .publisher .Sinks ;
24
23
import reactor .util .context .ContextView ;
@@ -99,21 +98,30 @@ class LifecycleInitializer {
99
98
*/
100
99
private final Duration initializationTimeout ;
101
100
101
+ /**
102
+ * Post-initialization hook to perform additional operations after every successful
103
+ * initialization.
104
+ */
105
+ private final Function <Initialization , Mono <Void >> postInitializationHook ;
106
+
102
107
public LifecycleInitializer (McpSchema .ClientCapabilities clientCapabilities , McpSchema .Implementation clientInfo ,
103
108
List <String > protocolVersions , Duration initializationTimeout ,
104
- Function <ContextView , McpClientSession > sessionSupplier ) {
109
+ Function <ContextView , McpClientSession > sessionSupplier ,
110
+ Function <Initialization , Mono <Void >> postInitializationHook ) {
105
111
106
112
Assert .notNull (sessionSupplier , "Session supplier must not be null" );
107
113
Assert .notNull (clientCapabilities , "Client capabilities must not be null" );
108
114
Assert .notNull (clientInfo , "Client info must not be null" );
109
115
Assert .notEmpty (protocolVersions , "Protocol versions must not be empty" );
110
116
Assert .notNull (initializationTimeout , "Initialization timeout must not be null" );
117
+ Assert .notNull (postInitializationHook , "Post-initialization hook must not be null" );
111
118
112
119
this .sessionSupplier = sessionSupplier ;
113
120
this .clientCapabilities = clientCapabilities ;
114
121
this .clientInfo = clientInfo ;
115
122
this .protocolVersions = Collections .unmodifiableList (new ArrayList <>(protocolVersions ));
116
123
this .initializationTimeout = initializationTimeout ;
124
+ this .postInitializationHook = postInitializationHook ;
117
125
}
118
126
119
127
/**
@@ -148,10 +156,6 @@ interface Initialization {
148
156
149
157
}
150
158
151
- /**
152
- * Default implementation of the {@link Initialization} interface that manages the MCP
153
- * client initialization process.
154
- */
155
159
private static class DefaultInitialization implements Initialization {
156
160
157
161
/**
@@ -199,29 +203,20 @@ private void setMcpClientSession(McpClientSession mcpClientSession) {
199
203
this .mcpClientSession .set (mcpClientSession );
200
204
}
201
205
202
- /**
203
- * Returns a Mono that completes when the MCP client initialization is complete.
204
- * This allows subscribers to wait for the initialization to finish before
205
- * proceeding with further operations.
206
- * @return A Mono that emits the result of the MCP initialization process
207
- */
208
206
private Mono <McpSchema .InitializeResult > await () {
209
207
return this .initSink .asMono ();
210
208
}
211
209
212
- /**
213
- * Completes the initialization process with the given result. It caches the
214
- * result and emits it to all subscribers waiting for the initialization to
215
- * complete.
216
- * @param initializeResult The result of the MCP initialization process
217
- */
218
210
private void complete (McpSchema .InitializeResult initializeResult ) {
219
- // first ensure the result is cached
220
- this .result .set (initializeResult );
221
211
// inform all the subscribers waiting for the initialization
222
212
this .initSink .emitValue (initializeResult , Sinks .EmitFailureHandler .FAIL_FAST );
223
213
}
224
214
215
+ private void cacheResult (McpSchema .InitializeResult initializeResult ) {
216
+ // first ensure the result is cached
217
+ this .result .set (initializeResult );
218
+ }
219
+
225
220
private void error (Throwable t ) {
226
221
this .initSink .emitError (t , Sinks .EmitFailureHandler .FAIL_FAST );
227
222
}
@@ -263,7 +258,7 @@ public void handleException(Throwable t) {
263
258
}
264
259
// Providing an empty operation since we are only interested in triggering
265
260
// the implicit initialization step.
266
- withIntitialization ("re-initializing" , result -> Mono .empty ()).subscribe ();
261
+ this . withInitialization ("re-initializing" , result -> Mono .empty ()).subscribe ();
267
262
}
268
263
}
269
264
@@ -275,16 +270,16 @@ public void handleException(Throwable t) {
275
270
* @param operation The operation to execute when the client is initialized
276
271
* @return A Mono that completes with the result of the operation
277
272
*/
278
- public <T > Mono <T > withIntitialization (String actionName , Function <Initialization , Mono <T >> operation ) {
273
+ public <T > Mono <T > withInitialization (String actionName , Function <Initialization , Mono <T >> operation ) {
279
274
return Mono .deferContextual (ctx -> {
280
275
DefaultInitialization newInit = new DefaultInitialization ();
281
276
DefaultInitialization previous = this .initializationRef .compareAndExchange (null , newInit );
282
277
283
278
boolean needsToInitialize = previous == null ;
284
279
logger .debug (needsToInitialize ? "Initialization process started" : "Joining previous initialization" );
285
280
286
- Mono <McpSchema .InitializeResult > initializationJob = needsToInitialize ? doInitialize ( newInit , ctx )
287
- : previous .await ();
281
+ Mono <McpSchema .InitializeResult > initializationJob = needsToInitialize
282
+ ? this . doInitialize ( newInit , this . postInitializationHook , ctx ) : previous .await ();
288
283
289
284
return initializationJob .map (initializeResult -> this .initializationRef .get ())
290
285
.timeout (this .initializationTimeout )
@@ -296,7 +291,9 @@ public <T> Mono<T> withIntitialization(String actionName, Function<Initializatio
296
291
});
297
292
}
298
293
299
- private Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization initialization , ContextView ctx ) {
294
+ private Mono <McpSchema .InitializeResult > doInitialize (DefaultInitialization initialization ,
295
+ Function <Initialization , Mono <Void >> postInitOperation , ContextView ctx ) {
296
+
300
297
initialization .setMcpClientSession (this .sessionSupplier .apply (ctx ));
301
298
302
299
McpClientSession mcpClientSession = initialization .mcpSession ();
@@ -323,6 +320,9 @@ private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization init
323
320
324
321
return mcpClientSession .sendNotification (McpSchema .METHOD_NOTIFICATION_INITIALIZED , null )
325
322
.thenReturn (initializeResult );
323
+ }).flatMap (initializeResult -> {
324
+ initialization .cacheResult (initializeResult );
325
+ return postInitOperation .apply (initialization ).thenReturn (initializeResult );
326
326
}).doOnNext (initialization ::complete ).onErrorResume (ex -> {
327
327
initialization .error (ex );
328
328
return Mono .error (ex );
0 commit comments