Skip to content

Commit 8dd7b8b

Browse files
committed
Implement MCP client opentelemetry tracing
finish client spans Inject tracer_provider into servers
1 parent 47acddd commit 8dd7b8b

File tree

8 files changed

+602
-18
lines changed

8 files changed

+602
-18
lines changed

src/mcp/client/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from dataclasses import KW_ONLY, dataclass, field
77
from typing import Any
88

9+
from opentelemetry.trace import TracerProvider
10+
911
from mcp.client._memory import InMemoryTransport
1012
from mcp.client._transport import Transport
1113
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
@@ -95,6 +97,8 @@ async def main():
9597
elicitation_callback: ElicitationFnT | None = None
9698
"""Callback for handling elicitation requests."""
9799

100+
tracer_provider: TracerProvider | None = None
101+
98102
_session: ClientSession | None = field(init=False, default=None)
99103
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
100104
_transport: Transport = field(init=False)
@@ -126,6 +130,7 @@ async def __aenter__(self) -> Client:
126130
message_handler=self.message_handler,
127131
client_info=self.client_info,
128132
elicitation_callback=self.elicitation_callback,
133+
tracer_provider=self.tracer_provider,
129134
)
130135
)
131136

src/mcp/client/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import anyio.lowlevel
77
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
8+
from opentelemetry.trace import TracerProvider
89
from pydantic import TypeAdapter
910

1011
from mcp import types
@@ -121,8 +122,11 @@ def __init__(
121122
*,
122123
sampling_capabilities: types.SamplingCapability | None = None,
123124
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
125+
tracer_provider: TracerProvider | None = None,
124126
) -> None:
125-
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
127+
super().__init__(
128+
read_stream, write_stream, read_timeout_seconds=read_timeout_seconds, tracer_provider=tracer_provider
129+
)
126130
self._client_info = client_info or DEFAULT_CLIENT_INFO
127131
self._sampling_callback = sampling_callback or _default_sampling_callback
128132
self._sampling_capabilities = sampling_capabilities

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ async def main():
4646

4747
import anyio
4848
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
49+
from opentelemetry.trace import TracerProvider
4950
from starlette.applications import Starlette
5051
from starlette.middleware import Middleware
5152
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -184,6 +185,7 @@ def __init__(
184185
Awaitable[None],
185186
]
186187
| None = None,
188+
tracer_provider: TracerProvider | None = None,
187189
):
188190
self.name = name
189191
self.version = version
@@ -199,6 +201,7 @@ def __init__(
199201
] = {}
200202
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
201203
self._session_manager: StreamableHTTPSessionManager | None = None
204+
self._tracer_provider = tracer_provider
202205
logger.debug("Initializing server %r", name)
203206

204207
# Populate internal handler dicts from on_* kwargs
@@ -380,6 +383,7 @@ async def run(
380383
write_stream,
381384
initialization_options,
382385
stateless=stateless,
386+
tracer_provider=self._tracer_provider,
383387
)
384388
)
385389

src/mcp/server/mcpserver/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import anyio
1414
import pydantic_core
15+
from opentelemetry.trace import TracerProvider
1516
from pydantic import BaseModel
1617
from pydantic.networks import AnyUrl
1718
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -146,6 +147,7 @@ def __init__(
146147
warn_on_duplicate_prompts: bool = True,
147148
lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
148149
auth: AuthSettings | None = None,
150+
tracer_provider: TracerProvider | None = None,
149151
):
150152
self.settings = Settings(
151153
debug=debug,
@@ -178,6 +180,7 @@ def __init__(
178180
# TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an MCPServer and Server.
179181
# We need to create a Lifespan type that is a generic on the server type, like Starlette does.
180182
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore
183+
tracer_provider=tracer_provider,
181184
)
182185
# Validate auth configuration
183186
if self.settings.auth is not None:

src/mcp/server/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
3434
import anyio
3535
import anyio.lowlevel
3636
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
37+
from opentelemetry.trace import TracerProvider
3738
from pydantic import AnyUrl, TypeAdapter
3839

3940
from mcp import types
@@ -83,8 +84,9 @@ def __init__(
8384
write_stream: MemoryObjectSendStream[SessionMessage],
8485
init_options: InitializationOptions,
8586
stateless: bool = False,
87+
tracer_provider: TracerProvider | None = None,
8688
) -> None:
87-
super().__init__(read_stream, write_stream)
89+
super().__init__(read_stream, write_stream, tracer_provider=tracer_provider)
8890
self._stateless = stateless
8991
self._initialization_state = (
9092
InitializationState.Initialized if stateless else InitializationState.NotInitialized

src/mcp/shared/_otel_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import contextlib
2+
from collections.abc import Iterator
3+
4+
from opentelemetry.trace import Span, SpanKind, Tracer
5+
6+
from mcp import types
7+
from mcp.shared.exceptions import MCPError
8+
9+
10+
@contextlib.contextmanager
11+
def mcp_client_span(
12+
tracer: Tracer,
13+
request: types.ClientRequest | types.ServerRequest | types.ClientNotification | types.ServerNotification,
14+
*,
15+
json_rpc_request_id: int | None = None,
16+
) -> Iterator[Span]:
17+
"""Starts an MCP client span as current span
18+
19+
https://github.com/open-telemetry/semantic-conventions/blob/v1.40.0/docs/gen-ai/mcp.md#client
20+
"""
21+
attributes = {"mcp.method.name": request.method}
22+
23+
# When omitted, the request is treated as a notification. Instrumentations SHOULD NOT
24+
# capture this attribute when the id is null or omitted.
25+
if json_rpc_request_id is not None:
26+
attributes["jsonrpc.request.id"] = str(json_rpc_request_id)
27+
28+
target = None
29+
30+
match request:
31+
case types.CallToolRequest():
32+
target = request.params.name
33+
attributes["gen_ai.tool.name"] = target
34+
attributes["gen_ai.operation.name"] = "execute_tool"
35+
case types.GetPromptRequest():
36+
target = request.params.name
37+
attributes["gen_ai.prompt.name"] = target
38+
case (
39+
types.ReadResourceRequest()
40+
| types.SubscribeRequest()
41+
| types.UnsubscribeRequest()
42+
| types.ResourceUpdatedNotification()
43+
):
44+
attributes["mcp.resource.uri"] = request.params.uri
45+
case _:
46+
pass
47+
48+
if target:
49+
span_name = f"{request.method} {target}"
50+
else:
51+
span_name = request.method
52+
53+
with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT, attributes=attributes) as span:
54+
try:
55+
yield span
56+
except MCPError as e:
57+
if span.is_recording():
58+
if e.code == types.REQUEST_TIMEOUT:
59+
span.set_attribute("error.type", "timeout")
60+
else:
61+
span.set_attribute("error.type", str(e.code))
62+
span.set_attribute("rpc.response.status_code", str(e.code))
63+
raise

src/mcp/shared/session.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212
from opentelemetry import context as otel_context
1313
from opentelemetry.propagate import extract, inject
14+
from opentelemetry.trace import TracerProvider, get_tracer
1415
from pydantic import BaseModel, TypeAdapter
1516
from typing_extensions import Self
1617

18+
from mcp.shared._otel_utils import mcp_client_span
1719
from mcp.shared.exceptions import MCPError
1820
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
1921
from mcp.shared.response_router import ResponseRouter
@@ -190,6 +192,8 @@ def __init__(
190192
write_stream: MemoryObjectSendStream[SessionMessage],
191193
# If none, reading will never time out
192194
read_timeout_seconds: float | None = None,
195+
*,
196+
tracer_provider: TracerProvider | None = None,
193197
) -> None:
194198
self._read_stream = read_stream
195199
self._write_stream = write_stream
@@ -200,6 +204,7 @@ def __init__(
200204
self._progress_callbacks = {}
201205
self._response_routers = []
202206
self._exit_stack = AsyncExitStack()
207+
self._tracer = get_tracer("mcp", tracer_provider=tracer_provider)
203208

204209
def add_response_router(self, router: ResponseRouter) -> None:
205210
"""Register a response router to handle responses for non-standard requests.
@@ -256,22 +261,22 @@ async def send_request(
256261
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
257262
self._response_streams[request_id] = response_stream
258263

259-
# Propagate opentelemetry trace context
260-
self._inject_otel_context(request)
261-
262-
# Set up progress token if progress callback is provided
263-
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
264-
if progress_callback is not None:
265-
# Use request_id as progress token
266-
if "params" not in request_data: # pragma: lax no cover
267-
request_data["params"] = {}
268-
if "_meta" not in request_data["params"]: # pragma: lax no cover
269-
request_data["params"]["_meta"] = {}
270-
request_data["params"]["_meta"]["progressToken"] = request_id
271-
# Store the callback for this request
272-
self._progress_callbacks[request_id] = progress_callback
264+
async def make_request() -> ReceiveResultT:
265+
# Propagate opentelemetry trace context
266+
self._inject_otel_context(request)
267+
268+
# Set up progress token if progress callback is provided
269+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
270+
if progress_callback is not None:
271+
# Use request_id as progress token
272+
if "params" not in request_data: # pragma: lax no cover
273+
request_data["params"] = {}
274+
if "_meta" not in request_data["params"]: # pragma: lax no cover
275+
request_data["params"]["_meta"] = {}
276+
request_data["params"]["_meta"]["progressToken"] = request_id
277+
# Store the callback for this request
278+
self._progress_callbacks[request_id] = progress_callback
273279

274-
try:
275280
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
276281
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
277282

@@ -291,6 +296,9 @@ async def send_request(
291296
else:
292297
return result_type.model_validate(response_or_error.result, by_name=False)
293298

299+
try:
300+
with mcp_client_span(self._tracer, request, json_rpc_request_id=request_id):
301+
return await make_request()
294302
finally:
295303
self._response_streams.pop(request_id, None)
296304
self._progress_callbacks.pop(request_id, None)
@@ -317,7 +325,9 @@ async def send_notification(
317325
message=jsonrpc_notification,
318326
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
319327
)
320-
await self._write_stream.send(session_message)
328+
329+
with mcp_client_span(self._tracer, notification):
330+
await self._write_stream.send(session_message)
321331

322332
def _inject_otel_context(self, request: SendRequestT | SendNotificationT) -> None:
323333
"""Propagate OpenTelemetry context in `_meta`.

0 commit comments

Comments
 (0)