diff --git a/Directory.Packages.props b/Directory.Packages.props
index 554361cbe..3bb00dbef 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -1,38 +1,63 @@
true
- 9.0.3
10.0.0-preview.2.25163.2
- 9.0.3
9.3.0-preview.1.25161.3
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
- runtime; build; native; contentfiles; analyzers; buildtransitive
- all
-
-
-
-
-
+
-
-
+
+
+
+
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+ all
+
-
-
-
+
+
+
+
+
+
+
diff --git a/README.md b/README.md
index 76cae4ee1..364354bcf 100644
--- a/README.md
+++ b/README.md
@@ -83,13 +83,24 @@ It includes a simple echo tool as an example (this is included in the same file
the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the
`McpTool` attribute as tools.)
+```
+dotnet add package ModelContextProtocol --prerelease
+dotnet add package Microsoft.Extensions.Hosting
+```
+
```csharp
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
using ModelContextProtocol.Server;
using System.ComponentModel;
-var builder = Host.CreateEmptyApplicationBuilder(settings: null);
+var builder = Host.CreateApplicationBuilder(args);
+builder.Logging.AddConsole(consoleLogOptions =>
+{
+ // Configure all logs to go to stderr
+ consoleLogOptions.LogToStandardErrorThreshold = LogLevel.Trace;
+});
builder.Services
.AddMcpServer()
.WithStdioServerTransport()
diff --git a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj
index c17cf9c45..94a5ccdb9 100644
--- a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj
+++ b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj
@@ -4,6 +4,7 @@
net9.0
enable
enable
+ true
diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs
index 774957e88..306a6e8f7 100644
--- a/samples/AspNetCoreSseServer/Program.cs
+++ b/samples/AspNetCoreSseServer/Program.cs
@@ -1,7 +1,10 @@
-using ModelContextProtocol.AspNetCore;
+using TestServerWithHosting.Tools;
var builder = WebApplication.CreateBuilder(args);
-builder.Services.AddMcpServer().WithToolsFromAssembly();
+builder.Services.AddMcpServer()
+ .WithTools()
+ .WithTools();
+
var app = builder.Build();
app.MapMcp();
diff --git a/samples/AspNetCoreSseServer/Tools/EchoTool.cs b/samples/AspNetCoreSseServer/Tools/EchoTool.cs
index 636b4063a..7913b73e4 100644
--- a/samples/AspNetCoreSseServer/Tools/EchoTool.cs
+++ b/samples/AspNetCoreSseServer/Tools/EchoTool.cs
@@ -4,7 +4,7 @@
namespace TestServerWithHosting.Tools;
[McpServerToolType]
-public static class EchoTool
+public sealed class EchoTool
{
[McpServerTool, Description("Echoes the input back to the client.")]
public static string Echo(string message)
diff --git a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs
index 4f175a453..4fbca594a 100644
--- a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs
+++ b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs
@@ -8,7 +8,7 @@ namespace TestServerWithHosting.Tools;
/// This tool uses dependency injection and async method
///
[McpServerToolType]
-public static class SampleLlmTool
+public sealed class SampleLlmTool
{
[McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")]
public static async Task SampleLLM(
diff --git a/samples/ChatWithTools/ChatWithTools.csproj b/samples/ChatWithTools/ChatWithTools.csproj
index af8fac198..8e08a455d 100644
--- a/samples/ChatWithTools/ChatWithTools.csproj
+++ b/samples/ChatWithTools/ChatWithTools.csproj
@@ -5,6 +5,10 @@
net8.0
enable
enable
+
diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs
index 364c2b870..1ecd40c25 100644
--- a/samples/QuickstartClient/Program.cs
+++ b/samples/QuickstartClient/Program.cs
@@ -5,7 +5,7 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;
-var builder = Host.CreateEmptyApplicationBuilder(settings: null);
+var builder = Host.CreateApplicationBuilder(args);
builder.Configuration
.AddEnvironmentVariables()
diff --git a/samples/QuickstartClient/QuickstartClient.csproj b/samples/QuickstartClient/QuickstartClient.csproj
index b820bedc1..b68f15e5f 100644
--- a/samples/QuickstartClient/QuickstartClient.csproj
+++ b/samples/QuickstartClient/QuickstartClient.csproj
@@ -6,6 +6,10 @@
enable
enable
a4e20a70-5009-4b81-b5b6-780b6d43e78e
+
@@ -15,6 +19,7 @@
+
diff --git a/samples/QuickstartWeatherServer/Program.cs b/samples/QuickstartWeatherServer/Program.cs
index a191cb163..301eeed5e 100644
--- a/samples/QuickstartWeatherServer/Program.cs
+++ b/samples/QuickstartWeatherServer/Program.cs
@@ -1,12 +1,19 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+using QuickstartWeatherServer.Tools;
using System.Net.Http.Headers;
-var builder = Host.CreateEmptyApplicationBuilder(settings: null);
+var builder = Host.CreateApplicationBuilder(args);
builder.Services.AddMcpServer()
.WithStdioServerTransport()
- .WithToolsFromAssembly();
+ .WithTools();
+
+builder.Logging.AddConsole(options =>
+{
+ options.LogToStandardErrorThreshold = LogLevel.Trace;
+});
builder.Services.AddSingleton(_ =>
{
diff --git a/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj b/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj
index 2e9154fd2..dc1108a8f 100644
--- a/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj
+++ b/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj
@@ -5,6 +5,7 @@
net8.0
enable
enable
+ true
diff --git a/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs b/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs
new file mode 100644
index 000000000..f7b2b5499
--- /dev/null
+++ b/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs
@@ -0,0 +1,13 @@
+using System.Text.Json;
+
+namespace ModelContextProtocol;
+
+internal static class HttpClientExt
+{
+ public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri)
+ {
+ using var response = await client.GetAsync(requestUri);
+ response.EnsureSuccessStatusCode();
+ return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync());
+ }
+}
\ No newline at end of file
diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs
index 697b80952..8463e3501 100644
--- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs
+++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs
@@ -1,3 +1,4 @@
+using ModelContextProtocol;
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Net.Http.Json;
@@ -6,14 +7,15 @@
namespace QuickstartWeatherServer.Tools;
[McpServerToolType]
-public static class WeatherTools
+public sealed class WeatherTools
{
[McpServerTool, Description("Get weather alerts for a US state.")]
public static async Task GetAlerts(
HttpClient client,
[Description("The US state to get alerts for.")] string state)
{
- var jsonElement = await client.GetFromJsonAsync($"/alerts/active/area/{state}");
+ using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}");
+ var jsonElement = jsonDocument.RootElement;
var alerts = jsonElement.GetProperty("features").EnumerateArray();
if (!alerts.Any())
@@ -40,7 +42,8 @@ public static async Task GetForecast(
[Description("Latitude of the location.")] double latitude,
[Description("Longitude of the location.")] double longitude)
{
- var jsonElement = await client.GetFromJsonAsync($"/points/{latitude},{longitude}");
+ using var jsonDocument = await client.ReadJsonDocumentAsync($"/points/{latitude},{longitude}");
+ var jsonElement = jsonDocument.RootElement;
var periods = jsonElement.GetProperty("properties").GetProperty("periods").EnumerateArray();
return string.Join("\n---\n", periods.Select(period => $"""
diff --git a/samples/TestServerWithHosting/Program.cs b/samples/TestServerWithHosting/Program.cs
index ee009084b..1ab6fc7a2 100644
--- a/samples/TestServerWithHosting/Program.cs
+++ b/samples/TestServerWithHosting/Program.cs
@@ -1,6 +1,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Serilog;
+using TestServerWithHosting.Tools;
Log.Logger = new LoggerConfiguration()
.MinimumLevel.Verbose() // Capture all log levels
@@ -19,7 +20,8 @@
builder.Services.AddSerilog();
builder.Services.AddMcpServer()
.WithStdioServerTransport()
- .WithToolsFromAssembly();
+ .WithTools()
+ .WithTools();
var app = builder.Build();
diff --git a/samples/TestServerWithHosting/TestServerWithHosting.csproj b/samples/TestServerWithHosting/TestServerWithHosting.csproj
index 137d753bb..0f3918d95 100644
--- a/samples/TestServerWithHosting/TestServerWithHosting.csproj
+++ b/samples/TestServerWithHosting/TestServerWithHosting.csproj
@@ -5,6 +5,10 @@
net9.0
enable
enable
+
@@ -13,6 +17,9 @@
+
diff --git a/samples/TestServerWithHosting/Tools/EchoTool.cs b/samples/TestServerWithHosting/Tools/EchoTool.cs
index 636b4063a..7913b73e4 100644
--- a/samples/TestServerWithHosting/Tools/EchoTool.cs
+++ b/samples/TestServerWithHosting/Tools/EchoTool.cs
@@ -4,7 +4,7 @@
namespace TestServerWithHosting.Tools;
[McpServerToolType]
-public static class EchoTool
+public sealed class EchoTool
{
[McpServerTool, Description("Echoes the input back to the client.")]
public static string Echo(string message)
diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs
index b1a0353d4..3539b4cdc 100644
--- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs
+++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs
@@ -8,7 +8,7 @@ namespace TestServerWithHosting.Tools;
/// This tool uses depenency injection and async method
///
[McpServerToolType]
-public static class SampleLlmTool
+public sealed class SampleLlmTool
{
[McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")]
public static async Task SampleLLM(
diff --git a/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs
new file mode 100644
index 000000000..89822eff1
--- /dev/null
+++ b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs
@@ -0,0 +1,17 @@
+using System.Runtime.CompilerServices;
+
+namespace System.Threading.Channels;
+
+internal static class ChannelExtensions
+{
+ public static async IAsyncEnumerable ReadAllAsync(this ChannelReader reader, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
+ {
+ while (reader.TryRead(out var item))
+ {
+ yield return item;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index 1d81cc273..f6f35ebd5 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -6,7 +6,7 @@
https://github.com/modelcontextprotocol/csharp-sdk
git
0.1.0
- preview.4
+ preview.5
ModelContextProtocolOfficial
© Anthropic and Contributors.
ModelContextProtocol;mcp;ai;llm
diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
index 9ad1848c4..bb96356b3 100644
--- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs
@@ -1,5 +1,4 @@
-using Microsoft.AspNetCore.Builder;
-using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.DependencyInjection;
@@ -12,7 +11,7 @@
using System.Collections.Concurrent;
using System.Security.Cryptography;
-namespace ModelContextProtocol.AspNetCore;
+namespace Microsoft.AspNetCore.Builder;
///
/// Extension methods for to add MCP endpoints.
@@ -40,7 +39,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
var requestAborted = context.RequestAborted;
response.Headers.ContentType = "text/event-stream";
- response.Headers.CacheControl = "no-cache";
+ response.Headers.CacheControl = "no-store";
var sessionId = MakeNewSessionId();
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
@@ -48,15 +47,15 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
{
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
}
- await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
try
{
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
- runSession ??= RunSession;
+ await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
try
{
+ runSession ??= RunSession;
await runSession(context, server, requestAborted);
}
finally
@@ -86,11 +85,11 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo
if (!_sessions.TryGetValue(sessionId.ToString(), out var transport))
{
- await Results.BadRequest($"Session {sessionId} not found.").ExecuteAsync(context);
+ await Results.BadRequest($"Session ID not found.").ExecuteAsync(context);
return;
}
- var message = await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions, context.RequestAborted);
+ var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
if (message is null)
{
await Results.BadRequest("No message in request body.").ExecuteAsync(context);
diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
index 5dd10dbf1..1bc4feb01 100644
--- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
+++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
@@ -1,7 +1,7 @@
- net8.0
+ net9.0;net8.0
enable
enable
true
@@ -9,6 +9,7 @@
ModelContextProtocol.AspNetCore
ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK.
README.md
+ true
diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md
index dd7a59094..457321d09 100644
--- a/src/ModelContextProtocol.AspNetCore/README.md
+++ b/src/ModelContextProtocol.AspNetCore/README.md
@@ -23,15 +23,15 @@ To get started, install the package from NuGet
```
dotnet new web
-dotnet add package ModelContextProtocol.AspNetcore --prerelease
+dotnet add package ModelContextProtocol.AspNetCore --prerelease
```
## Getting Started
```csharp
// Program.cs
-using ModelContextProtocol;
-using ModelContextProtocol.AspNetCore;
+using ModelContextProtocol.Server;
+using System.ComponentModel;
var builder = WebApplication.CreateBuilder(args);
builder.WebHost.ConfigureKestrel(options =>
diff --git a/src/ModelContextProtocol/Client/IMcpClient.cs b/src/ModelContextProtocol/Client/IMcpClient.cs
index 6761c4ef9..357ce3843 100644
--- a/src/ModelContextProtocol/Client/IMcpClient.cs
+++ b/src/ModelContextProtocol/Client/IMcpClient.cs
@@ -1,12 +1,11 @@
-using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Protocol.Types;
namespace ModelContextProtocol.Client;
///
/// Represents an instance of an MCP client connecting to a specific server.
///
-public interface IMcpClient : IAsyncDisposable
+public interface IMcpClient : IMcpEndpoint
{
///
/// Gets the capabilities supported by the server.
@@ -24,40 +23,4 @@ public interface IMcpClient : IAsyncDisposable
/// It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt.
///
string? ServerInstructions { get; }
-
- ///
- /// Adds a handler for server notifications of a specific method.
- ///
- /// The notification method to handle.
- /// The async handler function to process notifications.
- ///
- ///
- /// Each method may have multiple handlers. Adding a handler for a method that already has one
- /// will not replace the existing handler.
- ///
- ///
- /// provides constants for common notification methods.
- ///
- ///
- void AddNotificationHandler(string method, Func handler);
-
- ///
- /// Sends a generic JSON-RPC request to the server.
- ///
- /// The expected response type.
- /// The JSON-RPC request to send.
- /// A token to cancel the operation.
- /// A task containing the server's response.
- ///
- /// It is recommended to use the capability-specific methods that use this one in their implementation.
- /// Use this method for custom requests or those not yet covered explicitly.
- ///
- Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class;
-
- ///
- /// Sends a message to the server.
- ///
- /// The message.
- /// A token to cancel the operation.
- Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs
index b326f3c58..cf27a6b52 100644
--- a/src/ModelContextProtocol/Client/McpClient.cs
+++ b/src/ModelContextProtocol/Client/McpClient.cs
@@ -10,7 +10,7 @@
namespace ModelContextProtocol.Client;
///
-internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient
+internal sealed class McpClient : McpEndpoint, IMcpClient
{
private readonly IClientTransport _clientTransport;
private readonly McpClientOptions _options;
@@ -40,9 +40,14 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
}
- SetRequestHandler(
+ SetRequestHandler(
RequestMethods.SamplingCreateMessage,
- (request, ct) => samplingHandler(request, ct));
+ (request, cancellationToken) => samplingHandler(
+ request,
+ request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
+ cancellationToken),
+ McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
+ McpJsonUtilities.JsonContext.Default.CreateMessageResult);
}
if (options.Capabilities?.Roots is { } rootsCapability)
@@ -52,9 +57,11 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
}
- SetRequestHandler(
+ SetRequestHandler(
RequestMethods.RootsList,
- (request, ct) => rootsHandler(request, ct));
+ rootsHandler,
+ McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListRootsResult);
}
}
@@ -79,9 +86,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
// Connect transport
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
- // We don't want the ConnectAsync token to cancel the session after we've successfully connected.
- // The base class handles cleaning up the session in DisposeAsync without our help.
- StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
+ StartSession(_sessionTransport);
// Perform initialization sequence
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
@@ -90,18 +95,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
try
{
// Send initialize request
- var initializeResponse = await SendRequestAsync(
- new JsonRpcRequest
+ var initializeResponse = await this.SendRequestAsync(
+ RequestMethods.Initialize,
+ new InitializeRequestParams
{
- Method = RequestMethods.Initialize,
- Params = new InitializeRequestParams()
- {
- ProtocolVersion = _options.ProtocolVersion,
- Capabilities = _options.Capabilities ?? new ClientCapabilities(),
- ClientInfo = _options.ClientInfo
- }
+ ProtocolVersion = _options.ProtocolVersion,
+ Capabilities = _options.Capabilities ?? new ClientCapabilities(),
+ ClientInfo = _options.ClientInfo
},
- initializationCts.Token).ConfigureAwait(false);
+ McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
+ McpJsonUtilities.JsonContext.Default.InitializeResult,
+ cancellationToken: initializationCts.Token).ConfigureAwait(false);
// Store server information
_logger.ServerCapabilitiesReceived(EndpointName,
diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs
index 36742b14a..c98f76191 100644
--- a/src/ModelContextProtocol/Client/McpClientExtensions.cs
+++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs
@@ -1,35 +1,16 @@
-using ModelContextProtocol.Protocol.Messages;
+using Microsoft.Extensions.AI;
+using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
-using Microsoft.Extensions.AI;
-using System.Text.Json;
using System.Runtime.CompilerServices;
+using System.Text.Json;
namespace ModelContextProtocol.Client;
-///
-/// Provides extensions for operating on MCP clients.
-///
+/// Provides extension methods for interacting with an .
public static class McpClientExtensions
{
- ///
- /// Sends a notification to the server with parameters.
- ///
- /// The client.
- /// The notification method name.
- /// The parameters to send with the notification.
- /// A token to cancel the operation.
- public static Task SendNotificationAsync(this IMcpClient client, string method, object? parameters = null, CancellationToken cancellationToken = default)
- {
- Throw.IfNull(client);
- Throw.IfNullOrWhiteSpace(method);
-
- return client.SendMessageAsync(
- new JsonRpcNotification { Method = method, Params = parameters },
- cancellationToken);
- }
-
///
/// Sends a ping request to verify server connectivity.
///
@@ -40,34 +21,46 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat
{
Throw.IfNull(client);
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.Ping, null),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.Ping,
+ parameters: null,
+ McpJsonUtilities.JsonContext.Default.Object!,
+ McpJsonUtilities.JsonContext.Default.Object,
+ cancellationToken: cancellationToken);
}
///
/// Retrieves a list of available tools from the server.
///
/// The client.
+ /// The serializer options governing tool parameter serialization.
/// A token to cancel the operation.
/// A list of all available tools.
public static async Task> ListToolsAsync(
- this IMcpClient client, CancellationToken cancellationToken = default)
+ this IMcpClient client,
+ JsonSerializerOptions? serializerOptions = null,
+ CancellationToken cancellationToken = default)
{
Throw.IfNull(client);
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
+
List? tools = null;
string? cursor = null;
do
{
- var toolResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var toolResults = await client.SendRequestAsync(
+ RequestMethods.ToolsList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListToolsResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
tools ??= new List(toolResults.Tools.Count);
foreach (var tool in toolResults.Tools)
{
- tools.Add(new McpClientTool(client, tool));
+ tools.Add(new McpClientTool(client, tool, serializerOptions));
}
cursor = toolResults.NextCursor;
@@ -81,6 +74,7 @@ public static async Task> ListToolsAsync(
/// Creates an enumerable for asynchronously enumerating all available tools from the server.
///
/// The client.
+ /// The serializer options governing tool parameter serialization.
/// A token to cancel the operation.
/// An asynchronous sequence of all available tools.
///
@@ -88,20 +82,28 @@ public static async Task> ListToolsAsync(
/// will result in requerying the server and yielding the sequence of available tools.
///
public static async IAsyncEnumerable EnumerateToolsAsync(
- this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ this IMcpClient client,
+ JsonSerializerOptions? serializerOptions = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Throw.IfNull(client);
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
+
string? cursor = null;
do
{
- var toolResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var toolResults = await client.SendRequestAsync(
+ RequestMethods.ToolsList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListToolsResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
foreach (var tool in toolResults.Tools)
{
- yield return new McpClientTool(client, tool);
+ yield return new McpClientTool(client, tool, serializerOptions);
}
cursor = toolResults.NextCursor;
@@ -124,9 +126,12 @@ public static async Task> ListPromptsAsync(
string? cursor = null;
do
{
- var promptResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var promptResults = await client.SendRequestAsync(
+ RequestMethods.PromptsList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListPromptsResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
prompts ??= new List(promptResults.Prompts.Count);
foreach (var prompt in promptResults.Prompts)
@@ -159,9 +164,12 @@ public static async IAsyncEnumerable EnumeratePromptsAsync(
string? cursor = null;
do
{
- var promptResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var promptResults = await client.SendRequestAsync(
+ RequestMethods.PromptsList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListPromptsResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
foreach (var prompt in promptResults.Prompts)
{
@@ -179,17 +187,29 @@ public static async IAsyncEnumerable EnumeratePromptsAsync(
/// The client.
/// The name of the prompt to retrieve
/// Optional arguments for the prompt
+ /// The serialization options governing argument serialization.
/// A token to cancel the operation.
/// A task containing the prompt's content and messages.
public static Task GetPromptAsync(
- this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default)
+ this IMcpClient client,
+ string name,
+ IReadOnlyDictionary? arguments = null,
+ JsonSerializerOptions? serializerOptions = null,
+ CancellationToken cancellationToken = default)
{
Throw.IfNull(client);
Throw.IfNullOrWhiteSpace(name);
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.PromptsGet, CreateParametersDictionary(name, arguments)),
- cancellationToken);
+ var parametersTypeInfo = serializerOptions.GetTypeInfo>();
+
+ return client.SendRequestAsync(
+ RequestMethods.PromptsGet,
+ CreateParametersDictionary(name, arguments),
+ parametersTypeInfo,
+ McpJsonUtilities.JsonContext.Default.GetPromptResult,
+ cancellationToken: cancellationToken);
}
///
@@ -208,9 +228,12 @@ public static async Task> ListResourceTemplatesAsync(
string? cursor = null;
do
{
- var templateResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var templateResults = await client.SendRequestAsync(
+ RequestMethods.ResourcesTemplatesList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
if (templates is null)
{
@@ -246,9 +269,12 @@ public static async IAsyncEnumerable EnumerateResourceTemplate
string? cursor = null;
do
{
- var templateResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var templateResults = await client.SendRequestAsync(
+ RequestMethods.ResourcesTemplatesList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
foreach (var template in templateResults.ResourceTemplates)
{
@@ -276,9 +302,12 @@ public static async Task> ListResourcesAsync(
string? cursor = null;
do
{
- var resourceResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var resourceResults = await client.SendRequestAsync(
+ RequestMethods.ResourcesList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListResourcesResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
if (resources is null)
{
@@ -314,9 +343,12 @@ public static async IAsyncEnumerable EnumerateResourcesAsync(
string? cursor = null;
do
{
- var resourceResults = await client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)),
- cancellationToken).ConfigureAwait(false);
+ var resourceResults = await client.SendRequestAsync(
+ RequestMethods.ResourcesList,
+ CreateCursorDictionary(cursor)!,
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ListResourcesResult,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
foreach (var resource in resourceResults.Resources)
{
@@ -340,9 +372,12 @@ public static Task ReadResourceAsync(
Throw.IfNull(client);
Throw.IfNullOrWhiteSpace(uri);
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesRead, new Dictionary() { ["uri"] = uri }),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.ResourcesRead,
+ new Dictionary { ["uri"] = uri },
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.ReadResourceResult,
+ cancellationToken: cancellationToken);
}
///
@@ -364,13 +399,16 @@ public static Task GetCompletionAsync(this IMcpClient client, Re
throw new ArgumentException($"Invalid reference: {validationMessage}", nameof(reference));
}
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.CompletionComplete, new Dictionary()
+ return client.SendRequestAsync(
+ RequestMethods.CompletionComplete,
+ new Dictionary
{
["ref"] = reference,
["argument"] = new Argument { Name = argumentName, Value = argumentValue }
- }),
- cancellationToken);
+ },
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.CompleteResult,
+ cancellationToken: cancellationToken);
}
///
@@ -384,9 +422,12 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri,
Throw.IfNull(client);
Throw.IfNullOrWhiteSpace(uri);
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesSubscribe, new Dictionary() { ["uri"] = uri }),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.ResourcesSubscribe,
+ new Dictionary { ["uri"] = uri },
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.EmptyResult,
+ cancellationToken: cancellationToken);
}
///
@@ -400,9 +441,12 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u
Throw.IfNull(client);
Throw.IfNullOrWhiteSpace(uri);
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.ResourcesUnsubscribe, new Dictionary() { ["uri"] = uri }),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.ResourcesUnsubscribe,
+ new Dictionary { ["uri"] = uri },
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.EmptyResult,
+ cancellationToken: cancellationToken);
}
///
@@ -411,17 +455,29 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u
/// The client.
/// The name of the tool to call.
/// Optional arguments for the tool.
+ /// The serialization options governing argument serialization.
/// A token to cancel the operation.
/// A task containing the tool's response.
public static Task CallToolAsync(
- this IMcpClient client, string toolName, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default)
+ this IMcpClient client,
+ string toolName,
+ IReadOnlyDictionary? arguments = null,
+ JsonSerializerOptions? serializerOptions = null,
+ CancellationToken cancellationToken = default)
{
Throw.IfNull(client);
Throw.IfNull(toolName);
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
+
+ var parametersTypeInfo = serializerOptions.GetTypeInfo>();
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.ToolsCall, CreateParametersDictionary(toolName, arguments)),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.ToolsCall,
+ CreateParametersDictionary(toolName, arguments),
+ parametersTypeInfo,
+ McpJsonUtilities.JsonContext.Default.CallToolResponse,
+ cancellationToken: cancellationToken);
}
///
@@ -531,17 +587,33 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat
///
/// The with which to satisfy sampling requests.
/// The created handler delegate.
- public static Func> CreateSamplingHandler(this IChatClient chatClient)
+ public static Func, CancellationToken, Task> CreateSamplingHandler(
+ this IChatClient chatClient)
{
Throw.IfNull(chatClient);
- return async (requestParams, cancellationToken) =>
+ return async (requestParams, progress, cancellationToken) =>
{
Throw.IfNull(requestParams);
var (messages, options) = requestParams.ToChatClientArguments();
- var response = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
- return response.ToCreateMessageResult();
+ var progressToken = requestParams.Meta?.ProgressToken;
+
+ List updates = [];
+ await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken))
+ {
+ updates.Add(update);
+
+ if (progressToken is not null)
+ {
+ progress.Report(new()
+ {
+ Progress = updates.Count,
+ });
+ }
+ }
+
+ return updates.ToChatResponse().ToCreateMessageResult();
};
}
@@ -555,18 +627,14 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C
{
Throw.IfNull(client);
- return client.SendRequestAsync(
- CreateRequest(RequestMethods.LoggingSetLevel, new Dictionary() { ["level"] = level }),
- cancellationToken);
+ return client.SendRequestAsync(
+ RequestMethods.LoggingSetLevel,
+ new Dictionary { ["level"] = level },
+ McpJsonUtilities.JsonContext.Default.DictionaryStringObject,
+ McpJsonUtilities.JsonContext.Default.EmptyResult,
+ cancellationToken: cancellationToken);
}
- private static JsonRpcRequest CreateRequest(string method, IReadOnlyDictionary? parameters) =>
- new()
- {
- Method = method,
- Params = parameters
- };
-
private static Dictionary? CreateCursorDictionary(string? cursor) =>
cursor != null ? new() { ["cursor"] = cursor } : null;
@@ -585,36 +653,4 @@ private static JsonRpcRequest CreateRequest(string method, IReadOnlyDictionaryProvides an AI function that calls a tool through .
- private sealed class McpAIFunction(IMcpClient client, Tool tool) : AIFunction
- {
- ///
- public override string Name => tool.Name;
-
- ///
- public override string Description => tool.Description ?? string.Empty;
-
- ///
- public override JsonElement JsonSchema => tool.InputSchema;
-
- ///
- public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions;
-
- ///
- protected async override Task
/// Optional arguments for the prompt
+ /// The serialization options governing argument serialization.
/// A token to cancel the operation.
/// A task containing the prompt's content and messages.
public async ValueTask GetAsync(
IEnumerable>? arguments = null,
+ JsonSerializerOptions? serializerOptions = null,
CancellationToken cancellationToken = default)
{
IReadOnlyDictionary? argDict =
arguments as IReadOnlyDictionary ??
arguments?.ToDictionary();
- return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, cancellationToken).ConfigureAwait(false);
+ return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, serializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
}
/// Gets the name of the prompt.
diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs
index 10ccd81a6..58cb02d53 100644
--- a/src/ModelContextProtocol/Client/McpClientTool.cs
+++ b/src/ModelContextProtocol/Client/McpClientTool.cs
@@ -1,5 +1,6 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils.Json;
+using ModelContextProtocol.Utils;
using Microsoft.Extensions.AI;
using System.Text.Json;
@@ -9,27 +10,57 @@ namespace ModelContextProtocol.Client;
public sealed class McpClientTool : AIFunction
{
private readonly IMcpClient _client;
+ private readonly string _name;
+ private readonly string _description;
- internal McpClientTool(IMcpClient client, Tool tool)
+ internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, string? description = null)
{
_client = client;
ProtocolTool = tool;
+ JsonSerializerOptions = serializerOptions;
+ _name = name ?? tool.Name;
+ _description = description ?? tool.Description ?? string.Empty;
+ }
+
+ ///
+ /// Creates a new instance of the tool with the specified name.
+ /// This is useful for optimizing the tool name for specific models or for prefixing the tool name with a (usually server-derived) namespace to avoid conflicts.
+ /// The server will still be called with the original tool name, so no mapping is required.
+ ///
+ /// The model-facing name to give the tool.
+ /// Copy of this McpClientTool with the provided name
+ public McpClientTool WithName(string name)
+ {
+ return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description);
+ }
+
+ ///
+ /// Creates a new instance of the tool with the specified description.
+ /// This can be used to provide modified or additional (e.g. examples) context to the model about the tool.
+ /// This will in general require a hard-coded mapping in the client.
+ /// It is not recommended to use this without running evaluations to ensure the model actually benefits from the custom description.
+ ///
+ /// The description to give the tool.
+ /// Copy of this McpClientTool with the provided description
+ public McpClientTool WithDescription(string description)
+ {
+ return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description);
}
/// Gets the protocol type for this instance.
public Tool ProtocolTool { get; }
///
- public override string Name => ProtocolTool.Name;
+ public override string Name => _name;
///
- public override string Description => ProtocolTool.Description ?? string.Empty;
+ public override string Description => _description;
///
public override JsonElement JsonSchema => ProtocolTool.InputSchema;
///
- public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions;
+ public override JsonSerializerOptions JsonSerializerOptions { get; }
///
protected async override Task
public string? Location { get; set; }
- ///
- /// Arguments (if any) to pass to the executable.
- ///
- public string[]? Arguments { get; init; }
-
///
/// Additional transport-specific configuration.
///
diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs
new file mode 100644
index 000000000..5b4e31f4d
--- /dev/null
+++ b/src/ModelContextProtocol/Diagnostics.cs
@@ -0,0 +1,37 @@
+using System.Diagnostics;
+using System.Diagnostics.Metrics;
+
+namespace ModelContextProtocol;
+
+internal static class Diagnostics
+{
+ internal static ActivitySource ActivitySource { get; } = new("Experimental.ModelContextProtocol");
+
+ internal static Meter Meter { get; } = new("Experimental.ModelContextProtocol");
+
+ internal static Histogram CreateDurationHistogram(string name, string description, bool longBuckets) =>
+ Diagnostics.Meter.CreateHistogram(name, "s", description
+#if NET9_0_OR_GREATER
+ , advice: longBuckets ? LongSecondsBucketBoundaries : ShortSecondsBucketBoundaries
+#endif
+ );
+
+#if NET9_0_OR_GREATER
+ ///
+ /// Follows boundaries from http.server.request.duration/http.client.request.duration
+ ///
+ private static InstrumentAdvice ShortSecondsBucketBoundaries { get; } = new()
+ {
+ HistogramBucketBoundaries = [0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10],
+ };
+
+ ///
+ /// Not based on a standard. Larger bucket sizes for longer lasting operations, e.g. HTTP connection duration.
+ /// See https://github.com/open-telemetry/semantic-conventions/issues/336
+ ///
+ private static InstrumentAdvice LongSecondsBucketBoundaries { get; } = new()
+ {
+ HistogramBucketBoundaries = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300],
+ };
+#endif
+}
diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs
new file mode 100644
index 000000000..6643e02a7
--- /dev/null
+++ b/src/ModelContextProtocol/IMcpEndpoint.cs
@@ -0,0 +1,34 @@
+using ModelContextProtocol.Protocol.Messages;
+
+namespace ModelContextProtocol;
+
+/// Represents a client or server MCP endpoint.
+public interface IMcpEndpoint : IAsyncDisposable
+{
+ /// Sends a JSON-RPC request to the connected endpoint.
+ /// The JSON-RPC request to send.
+ /// A token to cancel the operation.
+ /// A task containing the client's response.
+ Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default);
+
+ /// Sends a message to the connected endpoint.
+ /// The message.
+ /// A token to cancel the operation.
+ Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
+
+ ///
+ /// Adds a handler for server notifications of a specific method.
+ ///
+ /// The notification method to handle.
+ /// The async handler function to process notifications.
+ ///
+ ///
+ /// Each method may have multiple handlers. Adding a handler for a method that already has one
+ /// will not replace the existing handler.
+ ///
+ ///
+ /// provides constants for common notification methods.
+ ///
+ ///
+ void AddNotificationHandler(string method, Func handler);
+}
diff --git a/src/ModelContextProtocol/Logging/Log.cs b/src/ModelContextProtocol/Logging/Log.cs
index d22c5d664..9b5d44bdf 100644
--- a/src/ModelContextProtocol/Logging/Log.cs
+++ b/src/ModelContextProtocol/Logging/Log.cs
@@ -77,9 +77,6 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Information, Message = "Request response received for {endpointName} with method {method}")]
internal static partial void RequestResponseReceived(this ILogger logger, string endpointName, string method);
- [LoggerMessage(Level = LogLevel.Error, Message = "Request response type conversion error for {endpointName} with method {method}: expected {expectedType}")]
- internal static partial void RequestResponseTypeConversionError(this ILogger logger, string endpointName, string method, Type expectedType);
-
[LoggerMessage(Level = LogLevel.Error, Message = "Request invalid response type for {endpointName} with method {method}")]
internal static partial void RequestInvalidResponseType(this ILogger logger, string endpointName, string method);
@@ -98,8 +95,8 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Information, Message = "Creating process for transport for {endpointName} with command {command}, arguments {arguments}, environment {environment}, working directory {workingDirectory}, shutdown timeout {shutdownTimeout}")]
internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory, string shutdownTimeout);
- [LoggerMessage(Level = LogLevel.Error, Message = "Transport for {endpointName} error: {data}")]
- internal static partial void TransportError(this ILogger logger, string endpointName, string data);
+ [LoggerMessage(Level = LogLevel.Information, Message = "Transport for {endpointName} received stderr log: {data}")]
+ internal static partial void ReadStderr(this ILogger logger, string endpointName, string data);
[LoggerMessage(Level = LogLevel.Information, Message = "Transport process start failed for {endpointName}")]
internal static partial void TransportProcessStartFailed(this ILogger logger, string endpointName);
diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs
new file mode 100644
index 000000000..d2a6a952e
--- /dev/null
+++ b/src/ModelContextProtocol/McpEndpointExtensions.cs
@@ -0,0 +1,168 @@
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization.Metadata;
+
+namespace ModelContextProtocol;
+
+/// Provides extension methods for interacting with an .
+public static class McpEndpointExtensions
+{
+ ///
+ /// Sends a JSON-RPC request and attempts to deserialize the result to .
+ ///
+ /// The type of the request parameters to serialize from.
+ /// The type of the result to deserialize to.
+ /// The MCP client or server instance.
+ /// The JSON-RPC method name to invoke.
+ /// Object representing the request parameters.
+ /// The request id for the request.
+ /// The options governing request serialization.
+ /// A token to cancel the operation.
+ /// A task that represents the asynchronous operation. The task result contains the deserialized result.
+ public static Task SendRequestAsync(
+ this IMcpEndpoint endpoint,
+ string method,
+ TParameters parameters,
+ JsonSerializerOptions? serializerOptions = null,
+ RequestId? requestId = null,
+ CancellationToken cancellationToken = default)
+ where TResult : notnull
+ {
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
+
+ JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo();
+ JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo();
+ return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken);
+ }
+
+ ///
+ /// Sends a JSON-RPC request and attempts to deserialize the result to .
+ ///
+ /// The type of the request parameters to serialize from.
+ /// The type of the result to deserialize to.
+ /// The MCP client or server instance.
+ /// The JSON-RPC method name to invoke.
+ /// Object representing the request parameters.
+ /// The type information for request parameter serialization.
+ /// The type information for request parameter deserialization.
+ /// The request id for the request.
+ /// A token to cancel the operation.
+ /// A task that represents the asynchronous operation. The task result contains the deserialized result.
+ internal static async Task SendRequestAsync(
+ this IMcpEndpoint endpoint,
+ string method,
+ TParameters parameters,
+ JsonTypeInfo parametersTypeInfo,
+ JsonTypeInfo resultTypeInfo,
+ RequestId? requestId = null,
+ CancellationToken cancellationToken = default)
+ where TResult : notnull
+ {
+ Throw.IfNull(endpoint);
+ Throw.IfNullOrWhiteSpace(method);
+ Throw.IfNull(parametersTypeInfo);
+ Throw.IfNull(resultTypeInfo);
+
+ JsonRpcRequest jsonRpcRequest = new()
+ {
+ Method = method,
+ Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo),
+ };
+
+ if (requestId is { } id)
+ {
+ jsonRpcRequest.Id = id;
+ }
+
+ JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false);
+ return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response.");
+ }
+
+ ///
+ /// Sends a notification to the server with parameters.
+ ///
+ /// The client.
+ /// The notification method name.
+ /// A token to cancel the operation.
+ public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default)
+ {
+ Throw.IfNull(client);
+ Throw.IfNullOrWhiteSpace(method);
+ return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken);
+ }
+
+ ///
+ /// Sends a notification to the server with parameters.
+ ///
+ /// The MCP client or server instance.
+ /// The JSON-RPC method name to invoke.
+ /// Object representing the request parameters.
+ /// The options governing request serialization.
+ /// A token to cancel the operation.
+ public static Task SendNotificationAsync(
+ this IMcpEndpoint endpoint,
+ string method,
+ TParameters parameters,
+ JsonSerializerOptions? serializerOptions = null,
+ CancellationToken cancellationToken = default)
+ {
+ serializerOptions ??= McpJsonUtilities.DefaultOptions;
+ serializerOptions.MakeReadOnly();
+
+ JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo();
+ return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken);
+ }
+
+ ///
+ /// Sends a notification to the server with parameters.
+ ///
+ /// The MCP client or server instance.
+ /// The JSON-RPC method name to invoke.
+ /// Object representing the request parameters.
+ /// The type information for request parameter serialization.
+ /// A token to cancel the operation.
+ internal static Task SendNotificationAsync(
+ this IMcpEndpoint endpoint,
+ string method,
+ TParameters parameters,
+ JsonTypeInfo parametersTypeInfo,
+ CancellationToken cancellationToken = default)
+ {
+ Throw.IfNull(endpoint);
+ Throw.IfNullOrWhiteSpace(method);
+ Throw.IfNull(parametersTypeInfo);
+
+ JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo);
+ return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken);
+ }
+
+ /// Notifies the connected endpoint of progress.
+ /// The endpoint issuing the notification.
+ /// The identifying the operation.
+ /// The progress update to send.
+ /// A token to cancel the operation.
+ /// A task representing the completion of the operation.
+ /// is .
+ public static Task NotifyProgressAsync(
+ this IMcpEndpoint endpoint,
+ ProgressToken progressToken,
+ ProgressNotificationValue progress,
+ CancellationToken cancellationToken = default)
+ {
+ Throw.IfNull(endpoint);
+
+ return endpoint.SendMessageAsync(new JsonRpcNotification()
+ {
+ Method = NotificationMethods.ProgressNotification,
+ Params = JsonSerializer.SerializeToNode(new ProgressNotification
+ {
+ ProgressToken = progressToken,
+ Progress = progress,
+ }, McpJsonUtilities.JsonContext.Default.ProgressNotification),
+ }, cancellationToken);
+ }
+}
diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj
index 6860381de..c120269bd 100644
--- a/src/ModelContextProtocol/ModelContextProtocol.csproj
+++ b/src/ModelContextProtocol/ModelContextProtocol.csproj
@@ -1,7 +1,7 @@
- net8.0;netstandard2.0
+ net9.0;net8.0;netstandard2.0
true
true
ModelContextProtocol
@@ -12,15 +12,8 @@
true
-
-
-
-
-
-
-
-
+
@@ -28,6 +21,24 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
index 359773712..302a46379 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs
@@ -1,4 +1,5 @@
-using System.Text.Json.Serialization;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
namespace ModelContextProtocol.Protocol.Messages;
@@ -23,5 +24,5 @@ public record JsonRpcNotification : IJsonRpcMessage
/// Optional parameters for the notification.
///
[JsonPropertyName("params")]
- public object? Params { get; init; }
+ public JsonNode? Params { get; init; }
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
index 9de16b8f2..8ae4a56de 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs
@@ -1,4 +1,5 @@
-using System.Text.Json.Serialization;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
namespace ModelContextProtocol.Protocol.Messages;
@@ -29,5 +30,5 @@ public record JsonRpcRequest : IJsonRpcMessageWithId
/// Optional parameters for the method.
///
[JsonPropertyName("params")]
- public object? Params { get; init; }
+ public JsonNode? Params { get; init; }
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
index a5a1238a3..9ced163f9 100644
--- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs
@@ -1,4 +1,4 @@
-
+using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
namespace ModelContextProtocol.Protocol.Messages;
@@ -23,5 +23,5 @@ public record JsonRpcResponse : IJsonRpcMessageWithId
/// The result of the method invocation.
///
[JsonPropertyName("result")]
- public required object? Result { get; init; }
+ public required JsonNode? Result { get; init; }
}
diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs
index 6183eb92e..b8a481086 100644
--- a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs
@@ -12,15 +12,15 @@ namespace ModelContextProtocol.Protocol.Messages;
[JsonConverter(typeof(Converter))]
public readonly struct ProgressToken : IEquatable
{
- /// The id, either a string or a boxed long or null.
- private readonly object? _id;
+ /// The token, either a string or a boxed long or null.
+ private readonly object? _token;
/// Initializes a new instance of the with a specified value.
/// The required ID value.
public ProgressToken(string value)
{
Throw.IfNull(value);
- _id = value;
+ _token = value;
}
/// Initializes a new instance of the with a specified value.
@@ -28,28 +28,29 @@ public ProgressToken(string value)
public ProgressToken(long value)
{
// Box the long. Progress tokens are almost always strings in practice, so this should be rare.
- _id = value;
+ _token = value;
}
- /// Gets whether the identifier is uninitialized.
- public bool IsDefault => _id is null;
+ /// Gets the underlying object for this token.
+ /// This will either be a , a boxed , or .
+ public object? Token => _token;
///
public override string? ToString() =>
- _id is string stringValue ? $"\"{stringValue}\"" :
- _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) :
+ _token is string stringValue ? $"{stringValue}" :
+ _token is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) :
null;
///
/// Compares this ProgressToken to another ProgressToken.
///
- public bool Equals(ProgressToken other) => Equals(_id, other._id);
+ public bool Equals(ProgressToken other) => Equals(_token, other._token);
///
public override bool Equals(object? obj) => obj is ProgressToken other && Equals(other);
///
- public override int GetHashCode() => _id?.GetHashCode() ?? 0;
+ public override int GetHashCode() => _token?.GetHashCode() ?? 0;
///
/// Compares two ProgressTokens for equality.
@@ -83,7 +84,7 @@ public override void Write(Utf8JsonWriter writer, ProgressToken value, JsonSeria
{
Throw.IfNull(writer);
- switch (value._id)
+ switch (value._token)
{
case string str:
writer.WriteStringValue(str);
diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs
index e6fc74418..550428ff9 100644
--- a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs
+++ b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs
@@ -31,12 +31,13 @@ public RequestId(long value)
_id = value;
}
- /// Gets whether the identifier is uninitialized.
- public bool IsDefault => _id is null;
+ /// Gets the underlying object for this id.
+ /// This will either be a , a boxed , or .
+ public object? Id => _id;
///
public override string ToString() =>
- _id is string stringValue ? $"\"{stringValue}\"" :
+ _id is string stringValue ? stringValue :
_id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) :
string.Empty;
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
index 2b4acac99..49b2fe401 100644
--- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs
@@ -24,7 +24,6 @@ internal sealed class SseClientSessionTransport : TransportBase
private Task? _receiveTask;
private readonly ILogger _logger;
private readonly McpServerConfig _serverConfig;
- private readonly JsonSerializerOptions _jsonOptions;
private readonly TaskCompletionSource _connectionEstablished;
private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})";
@@ -50,7 +49,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Mcp
_httpClient = httpClient;
_connectionCts = new CancellationTokenSource();
_logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
- _jsonOptions = McpJsonUtilities.DefaultOptions;
_connectionEstablished = new TaskCompletionSource();
}
@@ -94,7 +92,7 @@ public override async Task SendMessageAsync(
throw new InvalidOperationException("Transport not connected");
using var content = new StringContent(
- JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()),
+ JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage),
Encoding.UTF8,
"application/json"
);
@@ -127,7 +125,7 @@ public override async Task SendMessageAsync(
}
else
{
- JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, _jsonOptions.GetTypeInfo()) ??
+ JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
throw new McpTransportException("Failed to initialize client");
_logger.TransportReceivedMessageParsed(EndpointName, messageId);
@@ -259,7 +257,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
try
{
- var message = JsonSerializer.Deserialize(data, _jsonOptions.GetTypeInfo());
+ var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage);
if (message == null)
{
_logger.TransportMessageParseUnexpectedType(EndpointName, data);
diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
index eafd1f614..d4e39c8a4 100644
--- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs
@@ -32,19 +32,6 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string
/// A task representing the send loop that writes JSON-RPC messages to the SSE response stream.
public Task RunAsync(CancellationToken cancellationToken)
{
- void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer)
- {
- if (item.EventType == "endpoint")
- {
- writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
- return;
- }
-
- JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo());
- }
-
- IsConnected = true;
-
// The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type,
// so we fib and special-case the "endpoint" event type in the formatter.
if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint")))
@@ -52,10 +39,23 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter item, IBufferWriter writer)
+ {
+ if (item.EventType == "endpoint")
+ {
+ writer.Write(Encoding.UTF8.GetBytes(messageEndpoint));
+ return;
+ }
+
+ JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage!);
+ }
+
///
public ChannelReader MessageReader => _incomingChannel.Reader;
@@ -76,7 +76,8 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first.");
}
- await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message), cancellationToken);
+ // Emit redundant "event: message" lines for better compatibility with other SDKs.
+ await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken);
}
///
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
new file mode 100644
index 000000000..af304c897
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
@@ -0,0 +1,40 @@
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using System.Diagnostics;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+/// Provides the client side of a stdio-based session transport.
+internal sealed class StdioClientSessionTransport : StreamClientSessionTransport
+{
+ private readonly StdioClientTransportOptions _options;
+ private readonly Process _process;
+
+ public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory)
+ : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory)
+ {
+ _process = process;
+ _options = options;
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (_process.HasExited)
+ {
+ Logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ protected override ValueTask CleanupAsync(CancellationToken cancellationToken)
+ {
+ StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName);
+
+ return base.CleanupAsync(cancellationToken);
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs
deleted file mode 100644
index 35c957e52..000000000
--- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs
+++ /dev/null
@@ -1,326 +0,0 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
-using ModelContextProtocol.Logging;
-using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Utils;
-using ModelContextProtocol.Utils.Json;
-using System.Diagnostics;
-using System.Text;
-using System.Text.Json;
-
-namespace ModelContextProtocol.Protocol.Transport;
-
-///
-/// Implements the MCP transport protocol over standard input/output streams.
-///
-internal sealed class StdioClientStreamTransport : TransportBase
-{
- private readonly StdioClientTransportOptions _options;
- private readonly McpServerConfig _serverConfig;
- private readonly ILogger _logger;
- private readonly JsonSerializerOptions _jsonOptions;
- private readonly DataReceivedEventHandler _logProcessErrors;
- private readonly SemaphoreSlim _sendLock = new(1, 1);
- private Process? _process;
- private Task? _readTask;
- private CancellationTokenSource? _shutdownCts;
- private bool _processStarted;
-
- private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})";
-
- ///
- /// Initializes a new instance of the StdioTransport class.
- ///
- /// Configuration options for the transport.
- /// The server configuration for the transport.
- /// A logger factory for creating loggers.
- public StdioClientStreamTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null)
- : base(loggerFactory)
- {
- Throw.IfNull(options);
- Throw.IfNull(serverConfig);
-
- _options = options;
- _serverConfig = serverConfig;
- _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
- _logProcessErrors = (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)");
- _jsonOptions = McpJsonUtilities.DefaultOptions;
- }
-
- ///
- public async Task ConnectAsync(CancellationToken cancellationToken = default)
- {
- if (IsConnected)
- {
- _logger.TransportAlreadyConnected(EndpointName);
- throw new McpTransportException("Transport is already connected");
- }
-
- try
- {
- _logger.TransportConnecting(EndpointName);
-
- _shutdownCts = new CancellationTokenSource();
-
- UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false);
-
- var startInfo = new ProcessStartInfo
- {
- FileName = _options.Command,
- RedirectStandardInput = true,
- RedirectStandardOutput = true,
- RedirectStandardError = true,
- UseShellExecute = false,
- CreateNoWindow = true,
- WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
- StandardOutputEncoding = noBomUTF8,
- StandardErrorEncoding = noBomUTF8,
-#if NET
- StandardInputEncoding = noBomUTF8,
-#endif
- };
-
- if (!string.IsNullOrWhiteSpace(_options.Arguments))
- {
- startInfo.Arguments = _options.Arguments;
- }
-
- if (_options.EnvironmentVariables != null)
- {
- foreach (var entry in _options.EnvironmentVariables)
- {
- startInfo.Environment[entry.Key] = entry.Value;
- }
- }
-
- _logger.CreateProcessForTransport(EndpointName, _options.Command,
- startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)),
- startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString());
-
- _process = new Process { StartInfo = startInfo };
-
- // Set up error logging
- _process.ErrorDataReceived += _logProcessErrors;
-
- // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core,
- // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but
- // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks
- // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core,
- // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start
- // call, to ensure it picks up the correct encoding.
-#if NET
- _processStarted = _process.Start();
-#else
- Encoding originalInputEncoding = Console.InputEncoding;
- try
- {
- Console.InputEncoding = noBomUTF8;
- _processStarted = _process.Start();
- }
- finally
- {
- Console.InputEncoding = originalInputEncoding;
- }
-#endif
-
- if (!_processStarted)
- {
- _logger.TransportProcessStartFailed(EndpointName);
- throw new McpTransportException("Failed to start MCP server process");
- }
-
- _logger.TransportProcessStarted(EndpointName, _process.Id);
-
- _process.BeginErrorReadLine();
-
- // Start reading messages in the background
- _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None);
- _logger.TransportReadingMessages(EndpointName);
-
- SetConnected(true);
- }
- catch (Exception ex)
- {
- _logger.TransportConnectFailed(EndpointName, ex);
- await CleanupAsync(cancellationToken).ConfigureAwait(false);
- throw new McpTransportException("Failed to connect transport", ex);
- }
- }
-
- ///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
- {
- using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
-
- if (!IsConnected || _process?.HasExited == true)
- {
- _logger.TransportNotConnected(EndpointName);
- throw new McpTransportException("Transport is not connected");
- }
-
- string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- id = messageWithId.Id.ToString();
- }
-
- try
- {
- var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo());
- _logger.TransportSendingMessage(EndpointName, id, json);
- _logger.TransportMessageBytesUtf8(EndpointName, json);
-
- // Write the message followed by a newline using our UTF-8 writer
- await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false);
- await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false);
-
- _logger.TransportSentMessage(EndpointName, id);
- }
- catch (Exception ex)
- {
- _logger.TransportSendFailed(EndpointName, id, ex);
- throw new McpTransportException("Failed to send message", ex);
- }
- }
-
- ///
- public override async ValueTask DisposeAsync()
- {
- await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
- }
-
- private async Task ReadMessagesAsync(CancellationToken cancellationToken)
- {
- try
- {
- _logger.TransportEnteringReadMessagesLoop(EndpointName);
-
- while (!cancellationToken.IsCancellationRequested && !_process!.HasExited)
- {
- _logger.TransportWaitingForMessage(EndpointName);
- var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false);
- if (line == null)
- {
- _logger.TransportEndOfStream(EndpointName);
- break;
- }
-
- if (string.IsNullOrWhiteSpace(line))
- {
- continue;
- }
-
- _logger.TransportReceivedMessage(EndpointName, line);
- _logger.TransportMessageBytesUtf8(EndpointName, line);
-
- await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
- }
- _logger.TransportExitingReadMessagesLoop(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportReadMessagesCancelled(EndpointName);
- // Normal shutdown
- }
- catch (Exception ex)
- {
- _logger.TransportReadMessagesFailed(EndpointName, ex);
- }
- finally
- {
- await CleanupAsync(cancellationToken).ConfigureAwait(false);
- }
- }
-
- private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
- {
- try
- {
- line=line.Trim();//Fixes an error when the service prefixes nonprintable characters
- var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo());
- if (message != null)
- {
- string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- messageId = messageWithId.Id.ToString();
- }
- _logger.TransportReceivedMessageParsed(EndpointName, messageId);
- await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
- _logger.TransportMessageWritten(EndpointName, messageId);
- }
- else
- {
- _logger.TransportMessageParseUnexpectedType(EndpointName, line);
- }
- }
- catch (JsonException ex)
- {
- _logger.TransportMessageParseFailed(EndpointName, line, ex);
- }
- }
-
- private async Task CleanupAsync(CancellationToken cancellationToken)
- {
- _logger.TransportCleaningUp(EndpointName);
-
- if (_process is Process process && _processStarted && !process.HasExited)
- {
- try
- {
- // Wait for the process to exit
- _logger.TransportWaitingForShutdown(EndpointName);
-
- // Kill the while process tree because the process may spawn child processes
- // and Node.js does not kill its children when it exits properly
- process.KillTree(_options.ShutdownTimeout);
- }
- catch (Exception ex)
- {
- _logger.TransportShutdownFailed(EndpointName, ex);
- }
- finally
- {
- process.ErrorDataReceived -= _logProcessErrors;
- process.Dispose();
- _process = null;
- }
- }
-
- if (_shutdownCts is { } shutdownCts)
- {
- await shutdownCts.CancelAsync().ConfigureAwait(false);
- shutdownCts.Dispose();
- _shutdownCts = null;
- }
-
- if (_readTask is Task readTask)
- {
- try
- {
- _logger.TransportWaitingForReadTask(EndpointName);
- await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
- }
- catch (TimeoutException)
- {
- _logger.TransportCleanupReadTaskTimeout(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportCleanupReadTaskCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
- }
- finally
- {
- _logger.TransportReadTaskCleanedUp(EndpointName);
- _readTask = null;
- }
- }
-
- SetConnected(false);
- _logger.TransportCleanedUp(EndpointName);
- }
-}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
index d2b51b950..774677f13 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
@@ -1,10 +1,16 @@
using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
using ModelContextProtocol.Utils;
+using System.Diagnostics;
+using System.Text;
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
namespace ModelContextProtocol.Protocol.Transport;
///
-/// Implements the MCP transport protocol over standard input/output streams.
+/// Provides a client MCP transport implemented via "stdio" (standard input/output).
///
public sealed class StdioClientTransport : IClientTransport
{
@@ -13,7 +19,7 @@ public sealed class StdioClientTransport : IClientTransport
private readonly ILoggerFactory? _loggerFactory;
///
- /// Initializes a new instance of the StdioTransport class.
+ /// Initializes a new instance of the class.
///
/// Configuration options for the transport.
/// The server configuration for the transport.
@@ -31,17 +37,121 @@ public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig
///
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
- var streamTransport = new StdioClientStreamTransport(_options, _serverConfig, _loggerFactory);
+ string endpointName = $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})";
+
+ Process? process = null;
+ bool processStarted = false;
+ ILogger logger = (ILogger?)_loggerFactory?.CreateLogger() ?? NullLogger.Instance;
try
{
- await streamTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
- return streamTransport;
+ logger.TransportConnecting(endpointName);
+
+ UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false);
+
+ ProcessStartInfo startInfo = new()
+ {
+ FileName = _options.Command,
+ RedirectStandardInput = true,
+ RedirectStandardOutput = true,
+ RedirectStandardError = true,
+ UseShellExecute = false,
+ CreateNoWindow = true,
+ WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
+ StandardOutputEncoding = noBomUTF8,
+ StandardErrorEncoding = noBomUTF8,
+#if NET
+ StandardInputEncoding = noBomUTF8,
+#endif
+ };
+
+ if (!string.IsNullOrWhiteSpace(_options.Arguments))
+ {
+ startInfo.Arguments = _options.Arguments;
+ }
+
+ if (_options.EnvironmentVariables != null)
+ {
+ foreach (var entry in _options.EnvironmentVariables)
+ {
+ startInfo.Environment[entry.Key] = entry.Value;
+ }
+ }
+
+ logger.CreateProcessForTransport(endpointName, _options.Command,
+ startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)),
+ startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString());
+
+ process = new() { StartInfo = startInfo };
+
+ // Set up error logging
+ process.ErrorDataReceived += (sender, args) => logger.ReadStderr(endpointName, args.Data ?? "(no data)");
+
+ // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core,
+ // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but
+ // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks
+ // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core,
+ // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start
+ // call, to ensure it picks up the correct encoding.
+#if NET
+ processStarted = process.Start();
+#else
+ Encoding originalInputEncoding = Console.InputEncoding;
+ try
+ {
+ Console.InputEncoding = noBomUTF8;
+ processStarted = process.Start();
+ }
+ finally
+ {
+ Console.InputEncoding = originalInputEncoding;
+ }
+#endif
+
+ if (!processStarted)
+ {
+ logger.TransportProcessStartFailed(endpointName);
+ throw new McpTransportException("Failed to start MCP server process");
+ }
+
+ logger.TransportProcessStarted(endpointName, process.Id);
+
+ process.BeginErrorReadLine();
+
+ return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory);
}
- catch
+ catch (Exception ex)
+ {
+ logger.TransportConnectFailed(endpointName, ex);
+ DisposeProcess(process, processStarted, logger, _options.ShutdownTimeout, endpointName);
+ throw new McpTransportException("Failed to connect transport", ex);
+ }
+ }
+
+ internal static void DisposeProcess(
+ Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
+ {
+ if (process is not null)
{
- await streamTransport.DisposeAsync().ConfigureAwait(false);
- throw;
+ try
+ {
+ if (processStarted && !process.HasExited)
+ {
+ // Wait for the process to exit.
+ // Kill the while process tree because the process may spawn child processes
+ // and Node.js does not kill its children when it exits properly.
+ logger.TransportWaitingForShutdown(endpointName);
+ process.KillTree(shutdownTimeout);
+ }
+ }
+ catch (Exception ex)
+ {
+ logger.TransportShutdownFailed(endpointName, ex);
+ }
+ finally
+ {
+ process.Dispose();
+ }
}
}
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
index 7779edc99..58077dbb2 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
@@ -1,38 +1,15 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
-using ModelContextProtocol.Logging;
-using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
-using ModelContextProtocol.Utils.Json;
-using System.Text;
-using System.Text.Json;
namespace ModelContextProtocol.Protocol.Transport;
///
-/// Provides an implementation of the MCP transport protocol over standard input/output streams.
+/// Provides a server MCP transport implemented via "stdio" (standard input/output).
///
-public sealed class StdioServerTransport : TransportBase, ITransport
+public sealed class StdioServerTransport : StreamServerTransport
{
- private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();
-
- private readonly string _serverName;
- private readonly ILogger _logger;
-
- private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions;
- private readonly TextReader _stdInReader;
- private readonly Stream _stdOutStream;
-
- private readonly SemaphoreSlim _sendLock = new(1, 1);
- private readonly CancellationTokenSource _shutdownCts = new();
-
- private readonly Task _readLoopCompleted;
- private int _disposed = 0;
-
- private string EndpointName => $"Server (stdio) ({_serverName})";
-
///
/// Initializes a new instance of the class, using
/// and for input and output streams.
@@ -69,14 +46,6 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg
{
}
- private static string GetServerName(McpServerOptions serverOptions)
- {
- Throw.IfNull(serverOptions);
- Throw.IfNull(serverOptions.ServerInfo);
- Throw.IfNull(serverOptions.ServerInfo.Name);
- return serverOptions.ServerInfo.Name;
- }
-
///
/// Initializes a new instance of the class, using
/// and for input and output streams.
@@ -90,186 +59,20 @@ private static string GetServerName(McpServerOptions serverOptions)
/// to , as that will interfere with the transport's output.
///
///
- public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory)
- : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory)
- {
- }
-
- ///
- /// Initializes a new instance of the class with explicit input/output streams.
- ///
- /// The name of the server.
- /// The input to use as standard input. If , will be used.
- /// The output to use as standard output. If , will be used.
- /// Optional logger factory used for logging employed by the transport.
- /// is .
- ///
- ///
- /// This constructor is useful for testing scenarios where you want to redirect input/output.
- ///
- ///
- public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null)
- : base(loggerFactory)
- {
- Throw.IfNull(serverName);
-
- _serverName = serverName;
- _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
-
- _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8);
- _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput());
-
- SetConnected(true);
- _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token);
- }
-
- ///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null)
+ : base(Console.OpenStandardInput(),
+ new BufferedStream(Console.OpenStandardOutput()),
+ serverName ?? throw new ArgumentNullException(nameof(serverName)),
+ loggerFactory)
{
- using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
-
- if (!IsConnected)
- {
- _logger.TransportNotConnected(EndpointName);
- throw new McpTransportException("Transport is not connected");
- }
-
- string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- id = messageWithId.Id.ToString();
- }
-
- try
- {
- _logger.TransportSendingMessage(EndpointName, id);
-
- await JsonSerializer.SerializeAsync(_stdOutStream, message, _jsonOptions.GetTypeInfo(), cancellationToken).ConfigureAwait(false);
- await _stdOutStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
- await _stdOutStream.FlushAsync(cancellationToken).ConfigureAwait(false);;
-
- _logger.TransportSentMessage(EndpointName, id);
- }
- catch (Exception ex)
- {
- _logger.TransportSendFailed(EndpointName, id, ex);
- throw new McpTransportException("Failed to send message", ex);
- }
}
- private async Task ReadMessagesAsync()
- {
- CancellationToken shutdownToken = _shutdownCts.Token;
- try
- {
- _logger.TransportEnteringReadMessagesLoop(EndpointName);
-
- while (!shutdownToken.IsCancellationRequested)
- {
- _logger.TransportWaitingForMessage(EndpointName);
-
- var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
- if (string.IsNullOrWhiteSpace(line))
- {
- if (line is null)
- {
- _logger.TransportEndOfStream(EndpointName);
- break;
- }
-
- continue;
- }
-
- _logger.TransportReceivedMessage(EndpointName, line);
- _logger.TransportMessageBytesUtf8(EndpointName, line);
-
- try
- {
- if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message)
- {
- string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- messageId = messageWithId.Id.ToString();
- }
- _logger.TransportReceivedMessageParsed(EndpointName, messageId);
-
- await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
- _logger.TransportMessageWritten(EndpointName, messageId);
- }
- else
- {
- _logger.TransportMessageParseUnexpectedType(EndpointName, line);
- }
- }
- catch (JsonException ex)
- {
- _logger.TransportMessageParseFailed(EndpointName, line, ex);
- // Continue reading even if we fail to parse a message
- }
- }
-
- _logger.TransportExitingReadMessagesLoop(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportReadMessagesCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportReadMessagesFailed(EndpointName, ex);
- }
- finally
- {
- SetConnected(false);
- }
- }
-
- ///
- public override async ValueTask DisposeAsync()
+ private static string GetServerName(McpServerOptions serverOptions)
{
- if (Interlocked.Exchange(ref _disposed, 1) != 0)
- {
- return;
- }
-
- try
- {
- _logger.TransportCleaningUp(EndpointName);
-
- // Signal to the stdin reading loop to stop.
- await _shutdownCts.CancelAsync().ConfigureAwait(false);
- _shutdownCts.Dispose();
-
- // Dispose of stdin/out. Cancellation may not be able to wake up operations
- // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor.
- _stdInReader?.Dispose();
- _stdOutStream?.Dispose();
+ Throw.IfNull(serverOptions);
+ Throw.IfNull(serverOptions.ServerInfo);
+ Throw.IfNull(serverOptions.ServerInfo.Name);
- // Make sure the work has quiesced.
- try
- {
- _logger.TransportWaitingForReadTask(EndpointName);
- await _readLoopCompleted.ConfigureAwait(false);
- _logger.TransportReadTaskCleanedUp(EndpointName);
- }
- catch (TimeoutException)
- {
- _logger.TransportCleanupReadTaskTimeout(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportCleanupReadTaskCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
- }
- }
- finally
- {
- SetConnected(false);
- _logger.TransportCleanedUp(EndpointName);
- }
+ return serverOptions.ServerInfo.Name;
}
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
new file mode 100644
index 000000000..589e9078b
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
@@ -0,0 +1,195 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+/// Provides the client side of a stream-based session transport.
+internal class StreamClientSessionTransport : TransportBase
+{
+ private readonly TextReader _serverOutput;
+ private readonly TextWriter _serverInput;
+ private readonly SemaphoreSlim _sendLock = new(1, 1);
+ private CancellationTokenSource? _shutdownCts = new();
+ private Task? _readTask;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public StreamClientSessionTransport(
+ TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory)
+ : base(loggerFactory)
+ {
+ Logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
+ _serverOutput = serverOutput;
+ _serverInput = serverInput;
+ EndpointName = endpointName;
+
+ // Start reading messages in the background. We use the rarer pattern of new Task + Start
+ // in order to ensure that the body of the task will always see _readTask initialized.
+ // It is then able to reliably null it out on completion.
+ Logger.TransportReadingMessages(endpointName);
+ var readTask = new Task(
+ thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token),
+ this,
+ TaskCreationOptions.DenyChildAttach);
+ _readTask = readTask.Unwrap();
+ readTask.Start();
+
+ SetConnected(true);
+ }
+
+ protected ILogger Logger { get; private set; }
+
+ protected string EndpointName { get; }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (!IsConnected)
+ {
+ Logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+
+ using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ Logger.TransportSendingMessage(EndpointName, id, json);
+ Logger.TransportMessageBytesUtf8(EndpointName, json);
+
+ // Write the message followed by a newline using our UTF-8 writer
+ await _serverInput.WriteLineAsync(json).ConfigureAwait(false);
+ await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false);
+
+ Logger.TransportSentMessage(EndpointName, id);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportSendFailed(EndpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ ///
+ public override ValueTask DisposeAsync() =>
+ CleanupAsync(CancellationToken.None);
+
+ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
+ {
+ try
+ {
+ Logger.TransportEnteringReadMessagesLoop(EndpointName);
+
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ Logger.TransportWaitingForMessage(EndpointName);
+ if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
+ {
+ Logger.TransportEndOfStream(EndpointName);
+ break;
+ }
+
+ if (string.IsNullOrWhiteSpace(line))
+ {
+ continue;
+ }
+
+ Logger.TransportReceivedMessage(EndpointName, line);
+ Logger.TransportMessageBytesUtf8(EndpointName, line);
+
+ await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
+ }
+ Logger.TransportExitingReadMessagesLoop(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ Logger.TransportReadMessagesCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportReadMessagesFailed(EndpointName, ex);
+ }
+ finally
+ {
+ _readTask = null;
+ await CleanupAsync(cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
+ {
+ try
+ {
+ var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+ if (message != null)
+ {
+ string messageId = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ messageId = messageWithId.Id.ToString();
+ }
+
+ Logger.TransportReceivedMessageParsed(EndpointName, messageId);
+ await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ Logger.TransportMessageWritten(EndpointName, messageId);
+ }
+ else
+ {
+ Logger.TransportMessageParseUnexpectedType(EndpointName, line);
+ }
+ }
+ catch (JsonException ex)
+ {
+ Logger.TransportMessageParseFailed(EndpointName, line, ex);
+ }
+ }
+
+ protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken)
+ {
+ Logger.TransportCleaningUp(EndpointName);
+
+ if (Interlocked.Exchange(ref _shutdownCts, null) is { } shutdownCts)
+ {
+ await shutdownCts.CancelAsync().ConfigureAwait(false);
+ shutdownCts.Dispose();
+ }
+
+ if (Interlocked.Exchange(ref _readTask, null) is Task readTask)
+ {
+ try
+ {
+ Logger.TransportWaitingForReadTask(EndpointName);
+ await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
+ Logger.TransportReadTaskCleanedUp(EndpointName);
+ }
+ catch (TimeoutException)
+ {
+ Logger.TransportCleanupReadTaskTimeout(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ Logger.TransportCleanupReadTaskCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportCleanupReadTaskFailed(EndpointName, ex);
+ }
+ }
+
+ SetConnected(false);
+ Logger.TransportCleanedUp(EndpointName);
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs
new file mode 100644
index 000000000..80bd61df5
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs
@@ -0,0 +1,47 @@
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Utils;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides a client MCP transport implemented around a pair of input/output streams.
+///
+public sealed class StreamClientTransport : IClientTransport
+{
+ private readonly Stream _serverInput;
+ private readonly Stream _serverOutput;
+ private readonly ILoggerFactory? _loggerFactory;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// The stream representing the connected server's input.
+ /// Writes to this stream will be sent to the server.
+ ///
+ ///
+ /// The stream representing the connected server's output.
+ /// Reads from this stream will receive messages from the server.
+ ///
+ /// A logger factory for creating loggers.
+ public StreamClientTransport(
+ Stream serverInput, Stream serverOutput, ILoggerFactory? loggerFactory = null)
+ {
+ Throw.IfNull(serverInput);
+ Throw.IfNull(serverOutput);
+
+ _serverInput = serverInput;
+ _serverOutput = serverOutput;
+ _loggerFactory = loggerFactory;
+ }
+
+ ///
+ public Task ConnectAsync(CancellationToken cancellationToken = default)
+ {
+ return Task.FromResult(new StreamClientSessionTransport(
+ new StreamWriter(_serverInput),
+ new StreamReader(_serverOutput),
+ "Client (stream)",
+ _loggerFactory));
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
new file mode 100644
index 000000000..ebdf36350
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
@@ -0,0 +1,209 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.IO.Pipelines;
+using System.Text;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides a server MCP transport implemented around a pair of input/output streams.
+///
+public class StreamServerTransport : TransportBase, ITransport
+{
+ private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();
+
+ private readonly ILogger _logger;
+
+ private readonly TextReader _inputReader;
+ private readonly Stream _outputStream;
+ private readonly string _endpointName;
+
+ private readonly SemaphoreSlim _sendLock = new(1, 1);
+ private readonly CancellationTokenSource _shutdownCts = new();
+
+ private readonly Task _readLoopCompleted;
+ private int _disposed = 0;
+
+ ///
+ /// Initializes a new instance of the class with explicit input/output streams.
+ ///
+ /// The input to use as standard input.
+ /// The output to use as standard output.
+ /// Optional name of the server, used for diagnostic purposes, like logging.
+ /// Optional logger factory used for logging employed by the transport.
+ /// is .
+ /// is .
+ public StreamServerTransport(Stream inputStream, Stream outputStream, string? serverName = null, ILoggerFactory? loggerFactory = null)
+ : base(loggerFactory)
+ {
+ Throw.IfNull(inputStream);
+ Throw.IfNull(outputStream);
+
+ _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
+
+ _inputReader = new StreamReader(inputStream, Encoding.UTF8);
+ _outputStream = outputStream;
+
+ SetConnected(true);
+ _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token);
+
+ _endpointName = serverName is not null ? $"Server (stream) ({serverName})" : "Server (stream)";
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (!IsConnected)
+ {
+ _logger.TransportNotConnected(_endpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ try
+ {
+ _logger.TransportSendingMessage(_endpointName, id);
+
+ await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false);
+ await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
+ await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);;
+
+ _logger.TransportSentMessage(_endpointName, id);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportSendFailed(_endpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ private async Task ReadMessagesAsync()
+ {
+ CancellationToken shutdownToken = _shutdownCts.Token;
+ try
+ {
+ _logger.TransportEnteringReadMessagesLoop(_endpointName);
+
+ while (!shutdownToken.IsCancellationRequested)
+ {
+ _logger.TransportWaitingForMessage(_endpointName);
+
+ var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
+ if (string.IsNullOrWhiteSpace(line))
+ {
+ if (line is null)
+ {
+ _logger.TransportEndOfStream(_endpointName);
+ break;
+ }
+
+ continue;
+ }
+
+ _logger.TransportReceivedMessage(_endpointName, line);
+ _logger.TransportMessageBytesUtf8(_endpointName, line);
+
+ try
+ {
+ if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message)
+ {
+ string messageId = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ messageId = messageWithId.Id.ToString();
+ }
+ _logger.TransportReceivedMessageParsed(_endpointName, messageId);
+
+ await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
+ _logger.TransportMessageWritten(_endpointName, messageId);
+ }
+ else
+ {
+ _logger.TransportMessageParseUnexpectedType(_endpointName, line);
+ }
+ }
+ catch (JsonException ex)
+ {
+ _logger.TransportMessageParseFailed(_endpointName, line, ex);
+ // Continue reading even if we fail to parse a message
+ }
+ }
+
+ _logger.TransportExitingReadMessagesLoop(_endpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportReadMessagesCancelled(_endpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportReadMessagesFailed(_endpointName, ex);
+ }
+ finally
+ {
+ SetConnected(false);
+ }
+ }
+
+ ///
+ public override async ValueTask DisposeAsync()
+ {
+ if (Interlocked.Exchange(ref _disposed, 1) != 0)
+ {
+ return;
+ }
+
+ try
+ {
+ _logger.TransportCleaningUp(_endpointName);
+
+ // Signal to the stdin reading loop to stop.
+ await _shutdownCts.CancelAsync().ConfigureAwait(false);
+ _shutdownCts.Dispose();
+
+ // Dispose of stdin/out. Cancellation may not be able to wake up operations
+ // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor.
+ _inputReader?.Dispose();
+ _outputStream?.Dispose();
+
+ // Make sure the work has quiesced.
+ try
+ {
+ _logger.TransportWaitingForReadTask(_endpointName);
+ await _readLoopCompleted.ConfigureAwait(false);
+ _logger.TransportReadTaskCleanedUp(_endpointName);
+ }
+ catch (TimeoutException)
+ {
+ _logger.TransportCleanupReadTaskTimeout(_endpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportCleanupReadTaskCancelled(_endpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportCleanupReadTaskFailed(_endpointName, ex);
+ }
+ }
+ finally
+ {
+ SetConnected(false);
+ _logger.TransportCleanedUp(_endpointName);
+ }
+
+ GC.SuppressFinalize(this);
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs
index b7c8d3f63..08c6e5770 100644
--- a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs
@@ -1,4 +1,6 @@
-namespace ModelContextProtocol.Protocol.Types;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Protocol.Types;
///
/// Used by the client to invoke a tool provided by the server.
@@ -16,5 +18,5 @@ public class CallToolRequestParams : RequestParams
/// Optional arguments to pass to the tool.
///
[System.Text.Json.Serialization.JsonPropertyName("arguments")]
- public Dictionary? Arguments { get; init; }
+ public Dictionary? Arguments { get; init; }
}
diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs
index f33357783..c0cf41977 100644
--- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs
+++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs
@@ -55,7 +55,7 @@ public class SamplingCapability
/// Gets or sets the handler for sampling requests.
[JsonIgnore]
- public Func>? SamplingHandler { get; set; }
+ public Func, CancellationToken, Task>? SamplingHandler { get; set; }
}
///
diff --git a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs
index db45dd48f..bb3bae905 100644
--- a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs
@@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types;
/// A request to include context from one or more MCP servers (including the caller), to be attached to the prompt.
/// See the schema for details
///
-[JsonConverter(typeof(JsonStringEnumConverter))]
+[JsonConverter(typeof(CustomizableJsonStringEnumConverter))]
public enum ContextInclusion
{
///
diff --git a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs
index 419b6fceb..a5500d410 100644
--- a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs
@@ -4,12 +4,4 @@
/// Sent from the client to request a list of prompts and prompt templates the server has.
/// See the schema for details
///
-public class ListPromptsRequestParams
-{
- ///
- /// An opaque token representing the current pagination position.
- /// If provided, the server should return results starting after this cursor.
- ///
- [System.Text.Json.Serialization.JsonPropertyName("cursor")]
- public string? Cursor { get; init; }
-}
+public class ListPromptsRequestParams : PaginatedRequestParams;
diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs
index f4060dbd0..8a54f6e8e 100644
--- a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs
@@ -4,12 +4,4 @@
/// Sent from the client to request a list of resource templates the server has.
/// See the schema for details
///
-public class ListResourceTemplatesRequestParams
-{
- ///
- /// An opaque token representing the current pagination position.
- /// If provided, the server should return results starting after this cursor.
- ///
- [System.Text.Json.Serialization.JsonPropertyName("cursor")]
- public string? Cursor { get; init; }
-}
\ No newline at end of file
+public class ListResourceTemplatesRequestParams : PaginatedRequestParams;
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs
index ad7f19b31..30bea5b87 100644
--- a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs
@@ -4,12 +4,4 @@
/// Sent from the client to request a list of resources the server has.
/// See the schema for details
///
-public class ListResourcesRequestParams
-{
- ///
- /// An opaque token representing the current pagination position.
- /// If provided, the server should return results starting after this cursor.
- ///
- [System.Text.Json.Serialization.JsonPropertyName("cursor")]
- public string? Cursor { get; init; }
-}
+public class ListResourcesRequestParams : PaginatedRequestParams;
diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs
index dae1b75c1..a5eec7a15 100644
--- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs
@@ -6,11 +6,4 @@ namespace ModelContextProtocol.Protocol.Types;
/// A request from the server to get a list of root URIs from the client.
/// See the schema for details
///
-public class ListRootsRequestParams
-{
- ///
- /// Optional progress token for out-of-band progress notifications.
- ///
- [System.Text.Json.Serialization.JsonPropertyName("progressToken")]
- public ProgressToken? ProgressToken { get; init; }
-}
+public class ListRootsRequestParams : RequestParams;
diff --git a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs
index 4f18fbb73..64ac18599 100644
--- a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs
@@ -4,12 +4,4 @@
/// Sent from the client to request a list of tools the server has.
/// See the schema for details
///
-public class ListToolsRequestParams
-{
- ///
- /// An opaque token representing the current pagination position.
- /// If provided, the server should return results starting after this cursor.
- ///
- [System.Text.Json.Serialization.JsonPropertyName("cursor")]
- public string? Cursor { get; init; }
-}
+public class ListToolsRequestParams : PaginatedRequestParams;
diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs
index df8c4c75a..8098dbbd3 100644
--- a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs
+++ b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs
@@ -7,7 +7,7 @@ namespace ModelContextProtocol.Protocol.Types;
/// These map to syslog message severities, as specified in RFC-5424:
/// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1
///
-[JsonConverter(typeof(JsonStringEnumConverter))]
+[JsonConverter(typeof(CustomizableJsonStringEnumConverter))]
public enum LoggingLevel
{
/// Detailed debug information, typically only valuable to developers.
diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs b/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs
index 8c153f254..072992cad 100644
--- a/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs
+++ b/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs
@@ -24,7 +24,7 @@ public class LoggingMessageNotificationParams
public string? Logger { get; init; }
///
- /// The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here.
+ /// The data to be logged, such as a string message or an object.
///
[JsonPropertyName("data")]
public JsonElement? Data { get; init; }
diff --git a/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs
new file mode 100644
index 000000000..abf47dd3c
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs
@@ -0,0 +1,15 @@
+namespace ModelContextProtocol.Protocol.Types;
+
+///
+/// Used as a base class for paginated requests.
+/// See the schema for details
+///
+public class PaginatedRequestParams : RequestParams
+{
+ ///
+ /// An opaque token representing the current pagination position.
+ /// If provided, the server should return results starting after this cursor.
+ ///
+ [System.Text.Json.Serialization.JsonPropertyName("cursor")]
+ public string? Cursor { get; init; }
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs b/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs
index 1daece16b..0d70f8b43 100644
--- a/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs
+++ b/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs
@@ -61,6 +61,8 @@ public class Converter : JsonConverter
}
string? propertyName = reader.GetString();
+ bool success = reader.Read();
+ Debug.Assert(success, "STJ must have buffered the entire object for us.");
switch (propertyName)
{
diff --git a/src/ModelContextProtocol/Protocol/Types/Role.cs b/src/ModelContextProtocol/Protocol/Types/Role.cs
index 1cb35ea5b..c025f61ad 100644
--- a/src/ModelContextProtocol/Protocol/Types/Role.cs
+++ b/src/ModelContextProtocol/Protocol/Types/Role.cs
@@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types;
/// Represents the type of role in the conversation.
/// See the schema for details
///
-[JsonConverter(typeof(JsonStringEnumConverter))]
+[JsonConverter(typeof(CustomizableJsonStringEnumConverter))]
public enum Role
{
///
diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs
index ed0c71290..dc0b774c0 100644
--- a/src/ModelContextProtocol/Protocol/Types/Tool.cs
+++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs
@@ -38,7 +38,7 @@ public JsonElement InputSchema
{
if (!McpJsonUtilities.IsValidMcpToolSchema(value))
{
- throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema));
+ throw new ArgumentException("The specified document is not a valid MCP tool JSON schema.", nameof(InputSchema));
}
_inputSchema = value;
diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
index 47b8514de..0f33c8e41 100644
--- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
+++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
@@ -1,7 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol.Types;
-using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics.CodeAnalysis;
@@ -256,8 +255,8 @@ public override async Task InvokeAsync(
cancellationToken.ThrowIfCancellationRequested();
// TODO: Once we shift to the real AIFunctionFactory, the request should be passed via AIFunctionArguments.Context.
- Dictionary arguments = request.Params?.Arguments is IDictionary existingArgs ?
- new(existingArgs) :
+ Dictionary arguments = request.Params?.Arguments is { } paramArgs ?
+ paramArgs.ToDictionary(entry => entry.Key, entry => entry.Value.AsObject()) :
[];
arguments[RequestContextKey] = request;
diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs
index e8dffaf19..19b3967ad 100644
--- a/src/ModelContextProtocol/Server/IMcpServer.cs
+++ b/src/ModelContextProtocol/Server/IMcpServer.cs
@@ -1,12 +1,11 @@
-using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Protocol.Types;
namespace ModelContextProtocol.Server;
///
/// Represents a server that can communicate with a client using the MCP protocol.
///
-public interface IMcpServer : IAsyncDisposable
+public interface IMcpServer : IMcpEndpoint
{
///
/// Gets the capabilities supported by the client.
@@ -26,42 +25,8 @@ public interface IMcpServer : IAsyncDisposable
///
IServiceProvider? Services { get; }
- ///
- /// Adds a handler for client notifications of a specific method.
- ///
- /// The notification method to handle.
- /// The async handler function to process notifications.
- ///
- ///
- /// Each method may have multiple handlers. Adding a handler for a method that already has one
- /// will not replace the existing handler.
- ///
- ///
- /// provides constants for common notification methods.
- ///
- ///
- void AddNotificationHandler(string method, Func handler);
-
///
/// Runs the server, listening for and handling client requests.
///
Task RunAsync(CancellationToken cancellationToken = default);
-
- ///
- /// Sends a generic JSON-RPC request to the client.
- /// NB! This is a temporary method that is available to send not yet implemented feature messages.
- /// Once all MCP features are implemented this will be made private, as it is purely a convenience for those who wish to implement features ahead of the library.
- ///
- /// The expected response type.
- /// The JSON-RPC request to send.
- /// A token to cancel the operation.
- /// A task containing the client's response.
- Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class;
-
- ///
- /// Sends a message to the client.
- ///
- /// The message.
- /// A token to cancel the operation.
- Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
}
diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs
index b95483196..499f2efa8 100644
--- a/src/ModelContextProtocol/Server/McpServer.cs
+++ b/src/ModelContextProtocol/Server/McpServer.cs
@@ -1,22 +1,21 @@
using Microsoft.Extensions.Logging;
-using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils;
-using System.Text.Json.Nodes;
+using ModelContextProtocol.Utils.Json;
namespace ModelContextProtocol.Server;
///
-internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer
+internal sealed class McpServer : McpEndpoint, IMcpServer
{
private readonly EventHandler? _toolsChangedDelegate;
private readonly EventHandler? _promptsChangedDelegate;
- private ITransport _sessionTransport;
private string _endpointName;
+ private int _started;
///
/// Creates a new instance of .
@@ -33,7 +32,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
Throw.IfNull(transport);
Throw.IfNull(options);
- _sessionTransport = transport;
ServerOptions = options;
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})";
@@ -69,13 +67,14 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
});
SetToolsHandler(options);
-
SetInitializeHandler(options);
SetCompletionHandler(options);
SetPingHandler();
SetPromptsHandler(options);
SetResourcesHandler(options);
SetSetLoggingLevelHandler(options);
+
+ StartSession(transport);
}
public ServerCapabilities? ServerCapabilities { get; set; }
@@ -98,11 +97,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
///
public async Task RunAsync(CancellationToken cancellationToken = default)
{
+ if (Interlocked.Exchange(ref _started, 1) != 0)
+ {
+ throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once.");
+ }
+
try
{
- // Start processing messages
- StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken);
- await MessageProcessingTask.ConfigureAwait(false);
+ using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this);
+ // The McpServer ctor always calls StartSession, so MessageProcessingTask is always set.
+ await MessageProcessingTask!.ConfigureAwait(false);
}
finally
{
@@ -127,13 +131,15 @@ public override async ValueTask DisposeUnsynchronizedAsync()
private void SetPingHandler()
{
- SetRequestHandler(RequestMethods.Ping,
- (request, _) => Task.FromResult(new PingResult()));
+ SetRequestHandler(RequestMethods.Ping,
+ (request, _) => Task.FromResult(new PingResult()),
+ McpJsonUtilities.JsonContext.Default.JsonNode,
+ McpJsonUtilities.JsonContext.Default.PingResult);
}
private void SetInitializeHandler(McpServerOptions options)
{
- SetRequestHandler(RequestMethods.Initialize,
+ SetRequestHandler(RequestMethods.Initialize,
(request, _) =>
{
ClientCapabilities = request?.Capabilities ?? new();
@@ -143,23 +149,27 @@ private void SetInitializeHandler(McpServerOptions options)
_endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})";
GetSessionOrThrow().EndpointName = _endpointName;
- return Task.FromResult(new InitializeResult()
+ return Task.FromResult(new InitializeResult
{
ProtocolVersion = options.ProtocolVersion,
Instructions = options.ServerInstructions,
ServerInfo = options.ServerInfo,
Capabilities = ServerCapabilities ?? new(),
});
- });
+ },
+ McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
+ McpJsonUtilities.JsonContext.Default.InitializeResult);
}
private void SetCompletionHandler(McpServerOptions options)
{
// This capability is not optional, so return an empty result if there is no handler.
- SetRequestHandler(RequestMethods.CompletionComplete,
+ SetRequestHandler(RequestMethods.CompletionComplete,
options.GetCompletionHandler is { } handler ?
(request, ct) => handler(new(this, request), ct) :
- (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }));
+ (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }),
+ McpJsonUtilities.JsonContext.Default.CompleteRequestParams,
+ McpJsonUtilities.JsonContext.Default.CompleteResult);
}
private void SetResourcesHandler(McpServerOptions options)
@@ -180,11 +190,24 @@ private void SetResourcesHandler(McpServerOptions options)
listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult()));
- SetRequestHandler(RequestMethods.ResourcesList, (request, ct) => listResourcesHandler(new(this, request), ct));
- SetRequestHandler(RequestMethods.ResourcesRead, (request, ct) => readResourceHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.ResourcesList,
+ (request, ct) => listResourcesHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListResourcesResult);
+
+ SetRequestHandler(
+ RequestMethods.ResourcesRead,
+ (request, ct) => readResourceHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams,
+ McpJsonUtilities.JsonContext.Default.ReadResourceResult);
listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult()));
- SetRequestHandler(RequestMethods.ResourcesTemplatesList, (request, ct) => listResourceTemplatesHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.ResourcesTemplatesList,
+ (request, ct) => listResourceTemplatesHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult);
if (resourcesCapability.Subscribe is not true)
{
@@ -198,8 +221,17 @@ private void SetResourcesHandler(McpServerOptions options)
throw new McpServerException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified.");
}
- SetRequestHandler(RequestMethods.ResourcesSubscribe, (request, ct) => subscribeHandler(new(this, request), ct));
- SetRequestHandler(RequestMethods.ResourcesUnsubscribe, (request, ct) => unsubscribeHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.ResourcesSubscribe,
+ (request, ct) => subscribeHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.SubscribeRequestParams,
+ McpJsonUtilities.JsonContext.Default.EmptyResult);
+
+ SetRequestHandler(
+ RequestMethods.ResourcesUnsubscribe,
+ (request, ct) => unsubscribeHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams,
+ McpJsonUtilities.JsonContext.Default.EmptyResult);
}
private void SetPromptsHandler(McpServerOptions options)
@@ -214,41 +246,26 @@ private void SetPromptsHandler(McpServerOptions options)
throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together.");
}
- // Handle tools provided via DI.
+ // Handle prompts provided via DI.
if (prompts is { IsEmpty: false })
{
+ // Synthesize the handlers, making sure a PromptsCapability is specified.
var originalListPromptsHandler = listPromptsHandler;
- var originalGetPromptHandler = getPromptHandler;
-
- // Synthesize the handlers, making sure a ToolsCapability is specified.
listPromptsHandler = async (request, cancellationToken) =>
{
- ListPromptsResult result = new();
- foreach (McpServerPrompt prompt in prompts)
- {
- result.Prompts.Add(prompt.ProtocolPrompt);
- }
+ ListPromptsResult result = originalListPromptsHandler is not null ?
+ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) :
+ new();
- if (originalListPromptsHandler is not null)
+ if (request.Params?.Cursor is null)
{
- string? nextCursor = null;
- do
- {
- ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false);
- result.Prompts.AddRange(extraResults.Prompts);
-
- nextCursor = extraResults.NextCursor;
- if (nextCursor is not null)
- {
- request = request with { Params = new() { Cursor = nextCursor } };
- }
- }
- while (nextCursor is not null);
+ result.Prompts.AddRange(prompts.Select(t => t.ProtocolPrompt));
}
return result;
};
+ var originalGetPromptHandler = getPromptHandler;
getPromptHandler = (request, cancellationToken) =>
{
if (request.Params is null ||
@@ -297,8 +314,17 @@ private void SetPromptsHandler(McpServerOptions options)
}
}
- SetRequestHandler(RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct));
- SetRequestHandler(RequestMethods.PromptsGet, (request, ct) => getPromptHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.PromptsList,
+ (request, ct) => listPromptsHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListPromptsResult);
+
+ SetRequestHandler(
+ RequestMethods.PromptsGet,
+ (request, ct) => getPromptHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.GetPromptRequestParams,
+ McpJsonUtilities.JsonContext.Default.GetPromptResult);
}
private void SetToolsHandler(McpServerOptions options)
@@ -316,38 +342,23 @@ private void SetToolsHandler(McpServerOptions options)
// Handle tools provided via DI.
if (tools is { IsEmpty: false })
{
- var originalListToolsHandler = listToolsHandler;
- var originalCallToolHandler = callToolHandler;
-
// Synthesize the handlers, making sure a ToolsCapability is specified.
+ var originalListToolsHandler = listToolsHandler;
listToolsHandler = async (request, cancellationToken) =>
{
- ListToolsResult result = new();
- foreach (McpServerTool tool in tools)
- {
- result.Tools.Add(tool.ProtocolTool);
- }
+ ListToolsResult result = originalListToolsHandler is not null ?
+ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) :
+ new();
- if (originalListToolsHandler is not null)
+ if (request.Params?.Cursor is null)
{
- string? nextCursor = null;
- do
- {
- ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false);
- result.Tools.AddRange(extraResults.Tools);
-
- nextCursor = extraResults.NextCursor;
- if (nextCursor is not null)
- {
- request = request with { Params = new() { Cursor = nextCursor } };
- }
- }
- while (nextCursor is not null);
+ result.Tools.AddRange(tools.Select(t => t.ProtocolTool));
}
return result;
};
+ var originalCallToolHandler = callToolHandler;
callToolHandler = (request, cancellationToken) =>
{
if (request.Params is null ||
@@ -396,8 +407,17 @@ private void SetToolsHandler(McpServerOptions options)
}
}
- SetRequestHandler(RequestMethods.ToolsList, (request, ct) => listToolsHandler(new(this, request), ct));
- SetRequestHandler(RequestMethods.ToolsCall, (request, ct) => callToolHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.ToolsList,
+ (request, ct) => listToolsHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.ListToolsRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListToolsResult);
+
+ SetRequestHandler(
+ RequestMethods.ToolsCall,
+ (request, ct) => callToolHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
+ McpJsonUtilities.JsonContext.Default.CallToolResponse);
}
private void SetSetLoggingLevelHandler(McpServerOptions options)
@@ -412,6 +432,10 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)
throw new McpServerException("Logging capability was enabled, but SetLoggingLevelHandler was not specified.");
}
- SetRequestHandler(RequestMethods.LoggingSetLevel, (request, ct) => setLoggingLevelHandler(new(this, request), ct));
+ SetRequestHandler(
+ RequestMethods.LoggingSetLevel,
+ (request, ct) => setLoggingLevelHandler(new(this, request), ct),
+ McpJsonUtilities.JsonContext.Default.SetLevelRequestParams,
+ McpJsonUtilities.JsonContext.Default.EmptyResult);
}
}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs
index 3b541ec80..9b160d4c7 100644
--- a/src/ModelContextProtocol/Server/McpServerExtensions.cs
+++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs
@@ -1,13 +1,15 @@
-using ModelContextProtocol.Protocol.Messages;
+using Microsoft.Extensions.AI;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
-using Microsoft.Extensions.AI;
+using ModelContextProtocol.Utils.Json;
using System.Runtime.CompilerServices;
using System.Text;
namespace ModelContextProtocol.Server;
-///
+/// Provides extension methods for interacting with an .
public static class McpServerExtensions
{
///
@@ -25,9 +27,12 @@ public static Task RequestSamplingAsync(
throw new ArgumentException("Client connected to the server does not support sampling.", nameof(server));
}
- return server.SendRequestAsync(
- new JsonRpcRequest { Method = RequestMethods.SamplingCreateMessage, Params = request },
- cancellationToken);
+ return server.SendRequestAsync(
+ RequestMethods.SamplingCreateMessage,
+ request,
+ McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
+ McpJsonUtilities.JsonContext.Default.CreateMessageResult,
+ cancellationToken: cancellationToken);
}
///
@@ -164,9 +169,12 @@ public static Task RequestRootsAsync(
throw new ArgumentException("Client connected to the server does not support roots.", nameof(server));
}
- return server.SendRequestAsync(
- new JsonRpcRequest { Method = RequestMethods.RootsList, Params = request },
- cancellationToken);
+ return server.SendRequestAsync(
+ RequestMethods.RootsList,
+ request,
+ McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
+ McpJsonUtilities.JsonContext.Default.ListRootsResult,
+ cancellationToken: cancellationToken);
}
/// Provides an implementation that's implemented via client sampling.
diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs
index c79a09ed5..25bffe5ed 100644
--- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs
+++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs
@@ -5,6 +5,7 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
@@ -132,9 +133,9 @@ public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFu
///
///
/// Return values are serialized to using 's
- /// . Arguments that are not already of the expected type are
+ /// . Arguments that are not already of the expected type are
/// marshaled to the expected type via JSON and using 's
- /// . If the argument is a ,
+ /// . If the argument is a ,
/// , or , it is deserialized directly. If the argument is anything else unknown,
/// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type.
///
@@ -476,7 +477,7 @@ private static JsonElement CreateFunctionJsonSchema(
description: parameter.GetCustomAttribute(inherit: true)?.Description,
hasDefaultValue: parameter.HasDefaultValue,
defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null,
- serializerOptions), AIJsonUtilities.DefaultOptions.GetTypeInfo());
+ serializerOptions), McpJsonUtilities.JsonContext.Default.JsonElement);
parameterSchemas.Add(parameter.Name, parameterSchema);
if (!parameter.IsOptional)
diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs
index e1f712d1f..91403b3b9 100644
--- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs
+++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs
@@ -72,7 +72,7 @@ public TemporaryAIFunctionFactoryOptions()
/// Gets or sets a delegate used to determine the returned by .
///
///
- /// By default, the return value of invoking the method wrapped into an by
+ /// By default, the return value of invoking the method wrapped into an by
/// is then JSON serialized, with the resulting returned from the method.
/// This default behavior is ideal for the common case where the result will be passed back to an AI service. However, if the caller
/// requires more control over the result's marshaling, the property may be set to a delegate that is
@@ -82,7 +82,7 @@ public TemporaryAIFunctionFactoryOptions()
///
/// When set, the delegate is invoked even for -returning methods, in which case it is invoked with
/// a argument. By default, is returned from the
- /// method for instances produced by to wrap
+ /// method for instances produced by to wrap
/// -returning methods).
///
///
diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs
similarity index 64%
rename from src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs
rename to src/ModelContextProtocol/Shared/McpEndpoint.cs
index ac0337f9e..8b50a8052 100644
--- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs
+++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs
@@ -3,8 +3,10 @@
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using System.Diagnostics.CodeAnalysis;
+using System.Text.Json.Serialization.Metadata;
namespace ModelContextProtocol.Shared;
@@ -15,14 +17,13 @@ namespace ModelContextProtocol.Shared;
/// This is especially true as a client represents a connection to one and only one server, and vice versa.
/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction.
///
-internal abstract class McpJsonRpcEndpoint : IAsyncDisposable
+internal abstract class McpEndpoint : IAsyncDisposable
{
private readonly RequestHandlers _requestHandlers = [];
private readonly NotificationHandlers _notificationHandlers = [];
private McpSession? _session;
private CancellationTokenSource? _sessionCts;
- private int _started;
private readonly SemaphoreSlim _disposeLock = new(1, 1);
private bool _disposed;
@@ -30,22 +31,27 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable
protected readonly ILogger _logger;
///
- /// Initializes a new instance of the class.
+ /// Initializes a new instance of the class.
///
/// The logger factory.
- protected McpJsonRpcEndpoint(ILoggerFactory? loggerFactory = null)
+ protected McpEndpoint(ILoggerFactory? loggerFactory = null)
{
_logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
}
- protected void SetRequestHandler(string method, Func> handler)
- => _requestHandlers.Set(method, handler);
+ protected void SetRequestHandler(
+ string method,
+ Func> handler,
+ JsonTypeInfo requestTypeInfo,
+ JsonTypeInfo responseTypeInfo)
+
+ => _requestHandlers.Set(method, handler, requestTypeInfo, responseTypeInfo);
public void AddNotificationHandler(string method, Func handler)
=> _notificationHandlers.Add(method, handler);
- public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
- => GetSessionOrThrow().SendRequestAsync(request, cancellationToken);
+ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
+ => GetSessionOrThrow().SendRequestAsync(request, cancellationToken);
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);
@@ -58,21 +64,18 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
///
/// Task that processes incoming messages from the transport.
///
- protected Task? MessageProcessingTask { get; set; }
+ protected Task? MessageProcessingTask { get; private set; }
[MemberNotNull(nameof(MessageProcessingTask))]
- protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default)
+ protected void StartSession(ITransport sessionTransport)
{
- if (Interlocked.Exchange(ref _started, 1) != 0)
- {
- throw new InvalidOperationException("The MCP session has already stared.");
- }
-
- _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
- _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger);
+ _sessionCts = new CancellationTokenSource();
+ _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger);
MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
}
+ protected void CancelSession() => _sessionCts?.Cancel();
+
public async ValueTask DisposeAsync()
{
using var _ = await _disposeLock.LockAsync().ConfigureAwait(false);
@@ -94,25 +97,30 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
{
_logger.CleaningUpEndpoint(EndpointName);
- if (_sessionCts is not null)
- {
- await _sessionCts.CancelAsync().ConfigureAwait(false);
- }
-
- if (MessageProcessingTask is not null)
+ try
{
- try
+ if (_sessionCts is not null)
{
- await MessageProcessingTask.ConfigureAwait(false);
+ await _sessionCts.CancelAsync().ConfigureAwait(false);
}
- catch (OperationCanceledException)
+
+ if (MessageProcessingTask is not null)
{
- // Ignore cancellation
+ try
+ {
+ await MessageProcessingTask.ConfigureAwait(false);
+ }
+ catch (OperationCanceledException)
+ {
+ // Ignore cancellation
+ }
}
}
-
- _session?.Dispose();
- _sessionCts?.Dispose();
+ finally
+ {
+ _session?.Dispose();
+ _sessionCts?.Dispose();
+ }
_logger.EndpointCleanedUp(EndpointName);
}
diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs
index 97cbcb592..d5e4f930d 100644
--- a/src/ModelContextProtocol/Shared/McpSession.cs
+++ b/src/ModelContextProtocol/Shared/McpSession.cs
@@ -4,10 +4,15 @@
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
+using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Diagnostics.Metrics;
using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Threading.Channels;
namespace ModelContextProtocol.Shared;
@@ -16,9 +21,21 @@ namespace ModelContextProtocol.Shared;
///
internal sealed class McpSession : IDisposable
{
+ private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram(
+ "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true);
+ private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram(
+ "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true);
+ private static readonly Histogram s_clientRequestDuration = Diagnostics.CreateDurationHistogram(
+ "rpc.client.duration", "Measures the duration of outbound RPC.", longBuckets: false);
+ private static readonly Histogram s_serverRequestDuration = Diagnostics.CreateDurationHistogram(
+ "rpc.server.duration", "Measures the duration of inbound RPC.", longBuckets: false);
+
+ private readonly bool _isServer;
+ private readonly string _transportKind;
private readonly ITransport _transport;
private readonly RequestHandlers _requestHandlers;
private readonly NotificationHandlers _notificationHandlers;
+ private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp();
/// Collection of requests sent on this session and waiting for responses.
private readonly ConcurrentDictionary> _pendingRequests = [];
@@ -27,21 +44,22 @@ internal sealed class McpSession : IDisposable
/// that can be used to request cancellation of the in-flight handler.
///
private readonly ConcurrentDictionary _handlingRequests = new();
- private readonly JsonSerializerOptions _jsonOptions;
private readonly ILogger _logger;
-
+
private readonly string _id = Guid.NewGuid().ToString("N");
private long _nextRequestId;
///
/// Initializes a new instance of the class.
///
+ /// true if this is a server; false if it's a client.
/// An MCP transport implementation.
/// The name of the endpoint for logging and debug purposes.
/// A collection of request handlers.
/// A collection of notification handlers.
/// The logger.
public McpSession(
+ bool isServer,
ITransport transport,
string endpointName,
RequestHandlers requestHandlers,
@@ -50,11 +68,19 @@ public McpSession(
{
Throw.IfNull(transport);
+ _transportKind = transport switch
+ {
+ StdioClientSessionTransport or StdioServerTransport => "stdio",
+ StreamClientSessionTransport or StreamServerTransport => "stream",
+ SseClientSessionTransport or SseResponseStreamTransport => "sse",
+ _ => "unknownTransport"
+ };
+
+ _isServer = isServer;
_transport = transport;
EndpointName = endpointName;
_requestHandlers = requestHandlers;
_notificationHandlers = notificationHandlers;
- _jsonOptions = McpJsonUtilities.DefaultOptions;
_logger = logger ?? NullLogger.Instance;
}
@@ -121,14 +147,14 @@ await _transport.SendMessageAsync(new JsonRpcError
JsonRpc = "2.0",
Error = new JsonRpcErrorDetail
{
- Code = ErrorCodes.InternalError,
+ Code = (ex as McpServerException)?.ErrorCode ?? ErrorCodes.InternalError,
Message = ex.Message
}
}, cancellationToken).ConfigureAwait(false);
}
else if (ex is not OperationCanceledException)
{
- var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo());
+ var payload = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage);
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
}
}
@@ -148,27 +174,67 @@ await _transport.SendMessageAsync(new JsonRpcError
// Normal shutdown
_logger.EndpointMessageProcessingCancelled(EndpointName);
}
+ finally
+ {
+ // Fail any pending requests, as they'll never be satisfied.
+ foreach (var entry in _pendingRequests)
+ {
+ entry.Value.TrySetException(new InvalidOperationException("The server shut down unexpectedly."));
+ }
+ }
}
private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken)
{
- switch (message)
+ Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration;
+ string method = GetMethodName(message);
+
+ long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null;
+ Activity? activity = Diagnostics.ActivitySource.HasListeners() ?
+ Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) :
+ null;
+
+ TagList tags = default;
+ bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null;
+ try
{
- case JsonRpcRequest request:
- await HandleRequest(request, cancellationToken).ConfigureAwait(false);
- break;
+ if (addTags)
+ {
+ AddStandardTags(ref tags, method);
+ }
- case IJsonRpcMessageWithId messageWithId:
- HandleMessageWithId(message, messageWithId);
- break;
+ switch (message)
+ {
+ case JsonRpcRequest request:
+ if (addTags)
+ {
+ AddRpcRequestTags(ref tags, activity, request);
+ }
+
+ await HandleRequest(request, cancellationToken).ConfigureAwait(false);
+ break;
- case JsonRpcNotification notification:
- await HandleNotification(notification).ConfigureAwait(false);
- break;
+ case JsonRpcNotification notification:
+ await HandleNotification(notification).ConfigureAwait(false);
+ break;
- default:
- _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name);
- break;
+ case IJsonRpcMessageWithId messageWithId:
+ HandleMessageWithId(message, messageWithId);
+ break;
+
+ default:
+ _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name);
+ break;
+ }
+ }
+ catch (Exception e) when (addTags)
+ {
+ AddExceptionTags(ref tags, e);
+ throw;
+ }
+ finally
+ {
+ FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags);
}
}
@@ -212,7 +278,7 @@ private async Task HandleNotification(JsonRpcNotification notification)
private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId)
{
- if (messageWithId.Id.IsDefault)
+ if (messageWithId.Id.Id is null)
{
_logger.RequestHasInvalidId(EndpointName);
}
@@ -229,34 +295,32 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId
private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken)
{
- if (_requestHandlers.TryGetValue(request.Method, out var handler))
- {
- _logger.RequestHandlerCalled(EndpointName, request.Method);
- var result = await handler(request, cancellationToken).ConfigureAwait(false);
- _logger.RequestHandlerCompleted(EndpointName, request.Method);
- await _transport.SendMessageAsync(new JsonRpcResponse
- {
- Id = request.Id,
- JsonRpc = "2.0",
- Result = result
- }, cancellationToken).ConfigureAwait(false);
- }
- else
+ if (!_requestHandlers.TryGetValue(request.Method, out var handler))
{
_logger.NoHandlerFoundForRequest(EndpointName, request.Method);
+ throw new McpServerException("The method does not exist or is not available.", ErrorCodes.MethodNotFound);
}
+
+ _logger.RequestHandlerCalled(EndpointName, request.Method);
+ JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false);
+ _logger.RequestHandlerCompleted(EndpointName, request.Method);
+ await _transport.SendMessageAsync(new JsonRpcResponse
+ {
+ Id = request.Id,
+ JsonRpc = "2.0",
+ Result = result
+ }, cancellationToken).ConfigureAwait(false);
}
///
- /// Sends a generic JSON-RPC request to the server.
+ /// Sends a JSON-RPC request to the server.
/// It is strongly recommended use the capability-specific methods instead of this one.
/// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation.
///
- /// The expected response type.
/// The JSON-RPC request to send.
/// A token to cancel the operation.
/// A task containing the server's response.
- public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where TResult : class
+ public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken)
{
if (!_transport.IsConnected)
{
@@ -264,21 +328,37 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can
throw new McpClientException("Transport is not connected");
}
+ Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration;
+ string method = request.Method;
+
+ long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null;
+ using Activity? activity = Diagnostics.ActivitySource.HasListeners() ?
+ Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) :
+ null;
+
// Set request ID
- if (request.Id.IsDefault)
+ if (request.Id.Id is null)
{
request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}");
}
+ TagList tags = default;
+ bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null;
+
var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
_pendingRequests[request.Id] = tcs;
-
try
{
+ if (addTags)
+ {
+ AddStandardTags(ref tags, method);
+ AddRpcRequestTags(ref tags, activity, request);
+ }
+
// Expensive logging, use the logging framework to check if the logger is enabled
if (_logger.IsEnabled(LogLevel.Debug))
{
- _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, _jsonOptions.GetTypeInfo()));
+ _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcRequest));
}
// Less expensive information logging
@@ -297,31 +377,24 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can
if (response is JsonRpcResponse success)
{
- // Convert the Result object to JSON and back to get our strongly-typed result
- var resultJson = JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo