Skip to content

Commit 47acddd

Browse files
committed
Extract OTel context from _meta in incoming requests
This works for both incoming client->server requests and server->client requests like MCP Sampling.
2 parents b6b0647 + bd43714 commit 47acddd

File tree

7 files changed

+267
-15
lines changed

7 files changed

+267
-15
lines changed

src/mcp/client/sse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
132132
async def post_writer(endpoint_url: str):
133133
try:
134134
async with write_stream_reader:
135-
async for session_message in write_stream_reader:
135+
136+
async def handle_message(session_message: SessionMessage) -> None:
136137
logger.debug(f"Sending client message: {session_message}")
137138
response = await client.post(
138139
endpoint_url,
@@ -144,6 +145,13 @@ async def post_writer(endpoint_url: str):
144145
)
145146
response.raise_for_status()
146147
logger.debug(f"Client message sent successfully: {response.status_code}")
148+
149+
async for session_message in write_stream_reader:
150+
async with anyio.create_task_group() as tg_local:
151+
session_message.context.run(
152+
tg_local.start_soon, handle_message, session_message
153+
)
154+
147155
except Exception: # pragma: lax no cover
148156
logger.exception("Error in post_writer")
149157
finally:

src/mcp/client/streamable_http.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,8 @@ async def post_writer(
441441
"""Handle writing requests to the server."""
442442
try:
443443
async with write_stream_reader:
444-
async for session_message in write_stream_reader:
444+
445+
async def handle_message(session_message: SessionMessage) -> None:
445446
message = session_message.message
446447
metadata = (
447448
session_message.metadata
@@ -478,8 +479,12 @@ async def handle_request_async():
478479
else:
479480
await handle_request_async()
480481

481-
except Exception: # pragma: lax no cover
482-
logger.exception("Error in post_writer")
482+
async for session_message in write_stream_reader:
483+
async with anyio.create_task_group() as tg_local:
484+
session_message.context.run(tg_local.start_soon, handle_message, session_message)
485+
486+
except Exception:
487+
logger.exception("Error in post_writer") # pragma: no cover
483488
finally:
484489
await read_stream_writer.aclose()
485490
await write_stream.aclose()

src/mcp/server/lowlevel/server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,13 @@ async def run(
393393
async for message in session.incoming_messages:
394394
logger.debug("Received message: %s", message)
395395

396-
tg.start_soon(
396+
if isinstance(message, RequestResponder) and message.context is not None:
397+
context = message.context
398+
else:
399+
context = contextvars.copy_context()
400+
401+
context.run(
402+
tg.start_soon,
397403
self._handle_message,
398404
message,
399405
session,

src/mcp/shared/message.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
to support transport-specific features like resumability.
55
"""
66

7+
import contextvars
78
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
9+
from dataclasses import dataclass, field
910
from typing import Any
1011

1112
from mcp.types import JSONRPCMessage, RequestId
@@ -49,4 +50,5 @@ class SessionMessage:
4950
"""A message with specific metadata for transport-specific features."""
5051

5152
message: JSONRPCMessage
53+
context: contextvars.Context = field(default_factory=contextvars.copy_context)
5254
metadata: MessageMetadata = None

src/mcp/shared/session.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextvars
34
import logging
45
from collections.abc import Callable
56
from contextlib import AsyncExitStack
@@ -8,7 +9,8 @@
89

910
import anyio
1011
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11-
from opentelemetry.propagate import inject
12+
from opentelemetry import context as otel_context
13+
from opentelemetry.propagate import extract, inject
1214
from pydantic import BaseModel, TypeAdapter
1315
from typing_extensions import Self
1416

@@ -80,11 +82,13 @@ def __init__(
8082
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
8183
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
8284
message_metadata: MessageMetadata = None,
85+
context: contextvars.Context | None = None,
8386
) -> None:
8487
self.request_id = request_id
8588
self.request_meta = request_meta
8689
self.request = request
8790
self.message_metadata = message_metadata
91+
self.context = context
8892
self._session = session
8993
self._completed = False
9094
self._cancel_scope = anyio.CancelScope()
@@ -363,10 +367,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
363367
async def _receive_loop(self) -> None:
364368
async with self._read_stream, self._write_stream:
365369
try:
366-
async for message in self._read_stream:
367-
if isinstance(message, Exception): # pragma: no cover
368-
await self._handle_incoming(message)
369-
elif isinstance(message.message, JSONRPCRequest):
370+
371+
async def handle_message(message: SessionMessage) -> None:
372+
if isinstance(message.message, JSONRPCRequest):
370373
try:
371374
validated_request = self._receive_request_adapter.validate_python(
372375
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
@@ -379,6 +382,7 @@ async def _receive_loop(self) -> None:
379382
session=self,
380383
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
381384
message_metadata=message.metadata,
385+
context=message.context,
382386
)
383387
self._in_flight[responder.request_id] = responder
384388
await self._received_request(responder)
@@ -427,15 +431,35 @@ async def _receive_loop(self) -> None:
427431
logging.exception("Progress callback raised an exception")
428432
await self._received_notification(notification)
429433
await self._handle_incoming(notification)
430-
except Exception:
434+
except Exception: # pragma: lax no cover
431435
# For other validation errors, log and continue
432-
logging.warning( # pragma: no cover
436+
logging.warning(
433437
f"Failed to validate notification:. Message was: {message.message}",
434438
exc_info=True,
435439
)
436440
else: # Response or error
437441
await self._handle_response(message)
438442

443+
async def _handle_message_with_otel(message: SessionMessage) -> None:
444+
meta = None
445+
if isinstance(message.message, (JSONRPCRequest | JSONRPCNotification)) and message.message.params:
446+
meta = message.message.params.get("_meta")
447+
448+
extracted_ctx = extract(meta) if meta else None
449+
otel_token = otel_context.attach(extracted_ctx) if extracted_ctx else None
450+
try:
451+
await handle_message(message)
452+
finally:
453+
if otel_token:
454+
otel_context.detach(otel_token)
455+
456+
async for message in self._read_stream:
457+
if isinstance(message, Exception): # pragma: no cover
458+
await self._handle_incoming(message)
459+
else:
460+
async with anyio.create_task_group() as tg:
461+
message.context.run(tg.start_soon, _handle_message_with_otel, message)
462+
439463
except anyio.ClosedResourceError:
440464
# This is expected when the client disconnects abruptly.
441465
# Without this handler, the exception would propagate up and

tests/shared/test_otel_context_meta.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ async def tool_with_sampling(topic: str, ctx: Context[ServerSession, None]) -> s
7979
)
8080
return "ran sampling"
8181

82+
@mcp.tool()
83+
async def tool_that_checks_trace_context() -> str:
84+
"""Returns current span details to verify parent propagation."""
85+
return trace.format_trace_id(trace.get_current_span().get_span_context().trace_id)
86+
8287
return mcp
8388

8489

@@ -106,8 +111,9 @@ async def patched_client(server: MCPServer, monkeypatch: pytest.MonkeyPatch):
106111
async def sampling_callback(
107112
context: RequestContext[ClientSession], params: types.CreateMessageRequestParams
108113
) -> types.CreateMessageResult:
114+
current_trace_id = trace.format_trace_id(trace.get_current_span().get_span_context().trace_id)
109115
return types.CreateMessageResult(
110-
role="assistant", content=TextContent(type="text", text="hello"), model="foomodel"
116+
role="assistant", content=TextContent(type="text", text=current_trace_id), model="foomodel"
111117
)
112118

113119
async with create_client_server_memory_streams() as (client_streams, server_streams):
@@ -231,6 +237,17 @@ async def test_with_existing_meta(
231237
assert patched_client.client_to_server_messages == expect_client_to_server
232238

233239

240+
@pytest.mark.anyio
241+
async def test_trace_context_extraction(patched_client: PatchedClient):
242+
"""Test that OTEL context is successfully extracted on the receiving end."""
243+
244+
with trace.use_span(SPAN_IN_CLIENT):
245+
result = await patched_client.session.call_tool("tool_that_checks_trace_context")
246+
247+
# Verify that SPAN_IN_CLIENT was extracted and made it through to the handler
248+
assert result.content[0] == snapshot(TextContent(text="00000000000000000000000000000123"))
249+
250+
234251
@pytest.mark.anyio
235252
async def test_list_tools_with_span(patched_client: PatchedClient):
236253
"""Test that OTEL context is injected into the _meta field of a tools/list request."""
@@ -316,7 +333,11 @@ async def test_server_side_sampling_propagates_to_client(patched_client: Patched
316333
JSONRPCResponse(
317334
jsonrpc="2.0",
318335
id=0,
319-
result={"role": "assistant", "content": {"type": "text", "text": "hello"}, "model": "foomodel"},
336+
result={
337+
"role": "assistant",
338+
"content": {"type": "text", "text": "00000000000000000000000000000456"},
339+
"model": "foomodel",
340+
},
320341
),
321342
]
322343
)

0 commit comments

Comments
 (0)