11from __future__ import annotations
22
3+ import contextvars
34import logging
45from collections .abc import Callable
56from contextlib import AsyncExitStack
89
910import anyio
1011from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
11- from opentelemetry .propagate import inject
12+ from opentelemetry import context as otel_context
13+ from opentelemetry .propagate import extract , inject
1214from pydantic import BaseModel , TypeAdapter
1315from typing_extensions import Self
1416
@@ -80,11 +82,13 @@ def __init__(
8082 session : BaseSession [SendRequestT , SendNotificationT , SendResultT , ReceiveRequestT , ReceiveNotificationT ],
8183 on_complete : Callable [[RequestResponder [ReceiveRequestT , SendResultT ]], Any ],
8284 message_metadata : MessageMetadata = None ,
85+ context : contextvars .Context | None = None ,
8386 ) -> None :
8487 self .request_id = request_id
8588 self .request_meta = request_meta
8689 self .request = request
8790 self .message_metadata = message_metadata
91+ self .context = context
8892 self ._session = session
8993 self ._completed = False
9094 self ._cancel_scope = anyio .CancelScope ()
@@ -363,10 +367,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
363367 async def _receive_loop (self ) -> None :
364368 async with self ._read_stream , self ._write_stream :
365369 try :
366- async for message in self ._read_stream :
367- if isinstance (message , Exception ): # pragma: no cover
368- await self ._handle_incoming (message )
369- elif isinstance (message .message , JSONRPCRequest ):
370+
371+ async def handle_message (message : SessionMessage ) -> None :
372+ if isinstance (message .message , JSONRPCRequest ):
370373 try :
371374 validated_request = self ._receive_request_adapter .validate_python (
372375 message .message .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
@@ -379,6 +382,7 @@ async def _receive_loop(self) -> None:
379382 session = self ,
380383 on_complete = lambda r : self ._in_flight .pop (r .request_id , None ),
381384 message_metadata = message .metadata ,
385+ context = message .context ,
382386 )
383387 self ._in_flight [responder .request_id ] = responder
384388 await self ._received_request (responder )
@@ -427,15 +431,35 @@ async def _receive_loop(self) -> None:
427431 logging .exception ("Progress callback raised an exception" )
428432 await self ._received_notification (notification )
429433 await self ._handle_incoming (notification )
430- except Exception :
434+ except Exception : # pragma: lax no cover
431435 # For other validation errors, log and continue
432- logging .warning ( # pragma: no cover
436+ logging .warning (
433437 f"Failed to validate notification:. Message was: { message .message } " ,
434438 exc_info = True ,
435439 )
436440 else : # Response or error
437441 await self ._handle_response (message )
438442
443+ async def _handle_message_with_otel (message : SessionMessage ) -> None :
444+ meta = None
445+ if isinstance (message .message , (JSONRPCRequest | JSONRPCNotification )) and message .message .params :
446+ meta = message .message .params .get ("_meta" )
447+
448+ extracted_ctx = extract (meta ) if meta else None
449+ otel_token = otel_context .attach (extracted_ctx ) if extracted_ctx else None
450+ try :
451+ await handle_message (message )
452+ finally :
453+ if otel_token :
454+ otel_context .detach (otel_token )
455+
456+ async for message in self ._read_stream :
457+ if isinstance (message , Exception ): # pragma: no cover
458+ await self ._handle_incoming (message )
459+ else :
460+ async with anyio .create_task_group () as tg :
461+ message .context .run (tg .start_soon , _handle_message_with_otel , message )
462+
439463 except anyio .ClosedResourceError :
440464 # This is expected when the client disconnects abruptly.
441465 # Without this handler, the exception would propagate up and
0 commit comments