diff --git a/Directory.Packages.props b/Directory.Packages.props index 554361cbe..3bb00dbef 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -1,38 +1,63 @@ true - 9.0.3 10.0.0-preview.2.25163.2 - 9.0.3 9.3.0-preview.1.25161.3 + + + + + + + + + + + + + + + + + + + + + + + + + + - - - runtime; build; native; contentfiles; analyzers; buildtransitive - all - - - - - + - - + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + - - - + + + + + + + diff --git a/README.md b/README.md index 76cae4ee1..364354bcf 100644 --- a/README.md +++ b/README.md @@ -83,13 +83,24 @@ It includes a simple echo tool as an example (this is included in the same file the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the `McpTool` attribute as tools.) +``` +dotnet add package ModelContextProtocol --prerelease +dotnet add package Microsoft.Extensions.Hosting +``` + ```csharp using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Server; using System.ComponentModel; -var builder = Host.CreateEmptyApplicationBuilder(settings: null); +var builder = Host.CreateApplicationBuilder(args); +builder.Logging.AddConsole(consoleLogOptions => +{ + // Configure all logs to go to stderr + consoleLogOptions.LogToStandardErrorThreshold = LogLevel.Trace; +}); builder.Services .AddMcpServer() .WithStdioServerTransport() diff --git a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj index c17cf9c45..94a5ccdb9 100644 --- a/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj +++ b/samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj @@ -4,6 +4,7 @@ net9.0 enable enable + true diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index 774957e88..306a6e8f7 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -1,7 +1,10 @@ -using ModelContextProtocol.AspNetCore; +using TestServerWithHosting.Tools; var builder = WebApplication.CreateBuilder(args); -builder.Services.AddMcpServer().WithToolsFromAssembly(); +builder.Services.AddMcpServer() + .WithTools() + .WithTools(); + var app = builder.Build(); app.MapMcp(); diff --git a/samples/AspNetCoreSseServer/Tools/EchoTool.cs b/samples/AspNetCoreSseServer/Tools/EchoTool.cs index 636b4063a..7913b73e4 100644 --- a/samples/AspNetCoreSseServer/Tools/EchoTool.cs +++ b/samples/AspNetCoreSseServer/Tools/EchoTool.cs @@ -4,7 +4,7 @@ namespace TestServerWithHosting.Tools; [McpServerToolType] -public static class EchoTool +public sealed class EchoTool { [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo(string message) diff --git a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs index 4f175a453..4fbca594a 100644 --- a/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreSseServer/Tools/SampleLlmTool.cs @@ -8,7 +8,7 @@ namespace TestServerWithHosting.Tools; /// This tool uses dependency injection and async method /// [McpServerToolType] -public static class SampleLlmTool +public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( diff --git a/samples/ChatWithTools/ChatWithTools.csproj b/samples/ChatWithTools/ChatWithTools.csproj index af8fac198..8e08a455d 100644 --- a/samples/ChatWithTools/ChatWithTools.csproj +++ b/samples/ChatWithTools/ChatWithTools.csproj @@ -5,6 +5,10 @@ net8.0 enable enable + diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index 364c2b870..1ecd40c25 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -5,7 +5,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; -var builder = Host.CreateEmptyApplicationBuilder(settings: null); +var builder = Host.CreateApplicationBuilder(args); builder.Configuration .AddEnvironmentVariables() diff --git a/samples/QuickstartClient/QuickstartClient.csproj b/samples/QuickstartClient/QuickstartClient.csproj index b820bedc1..b68f15e5f 100644 --- a/samples/QuickstartClient/QuickstartClient.csproj +++ b/samples/QuickstartClient/QuickstartClient.csproj @@ -6,6 +6,10 @@ enable enable a4e20a70-5009-4b81-b5b6-780b6d43e78e + @@ -15,6 +19,7 @@ + diff --git a/samples/QuickstartWeatherServer/Program.cs b/samples/QuickstartWeatherServer/Program.cs index a191cb163..301eeed5e 100644 --- a/samples/QuickstartWeatherServer/Program.cs +++ b/samples/QuickstartWeatherServer/Program.cs @@ -1,12 +1,19 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using QuickstartWeatherServer.Tools; using System.Net.Http.Headers; -var builder = Host.CreateEmptyApplicationBuilder(settings: null); +var builder = Host.CreateApplicationBuilder(args); builder.Services.AddMcpServer() .WithStdioServerTransport() - .WithToolsFromAssembly(); + .WithTools(); + +builder.Logging.AddConsole(options => +{ + options.LogToStandardErrorThreshold = LogLevel.Trace; +}); builder.Services.AddSingleton(_ => { diff --git a/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj b/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj index 2e9154fd2..dc1108a8f 100644 --- a/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj +++ b/samples/QuickstartWeatherServer/QuickstartWeatherServer.csproj @@ -5,6 +5,7 @@ net8.0 enable enable + true diff --git a/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs b/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs new file mode 100644 index 000000000..f7b2b5499 --- /dev/null +++ b/samples/QuickstartWeatherServer/Tools/HttpClientExt.cs @@ -0,0 +1,13 @@ +using System.Text.Json; + +namespace ModelContextProtocol; + +internal static class HttpClientExt +{ + public static async Task ReadJsonDocumentAsync(this HttpClient client, string requestUri) + { + using var response = await client.GetAsync(requestUri); + response.EnsureSuccessStatusCode(); + return await JsonDocument.ParseAsync(await response.Content.ReadAsStreamAsync()); + } +} \ No newline at end of file diff --git a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs index 697b80952..8463e3501 100644 --- a/samples/QuickstartWeatherServer/Tools/WeatherTools.cs +++ b/samples/QuickstartWeatherServer/Tools/WeatherTools.cs @@ -1,3 +1,4 @@ +using ModelContextProtocol; using ModelContextProtocol.Server; using System.ComponentModel; using System.Net.Http.Json; @@ -6,14 +7,15 @@ namespace QuickstartWeatherServer.Tools; [McpServerToolType] -public static class WeatherTools +public sealed class WeatherTools { [McpServerTool, Description("Get weather alerts for a US state.")] public static async Task GetAlerts( HttpClient client, [Description("The US state to get alerts for.")] string state) { - var jsonElement = await client.GetFromJsonAsync($"/alerts/active/area/{state}"); + using var jsonDocument = await client.ReadJsonDocumentAsync($"/alerts/active/area/{state}"); + var jsonElement = jsonDocument.RootElement; var alerts = jsonElement.GetProperty("features").EnumerateArray(); if (!alerts.Any()) @@ -40,7 +42,8 @@ public static async Task GetForecast( [Description("Latitude of the location.")] double latitude, [Description("Longitude of the location.")] double longitude) { - var jsonElement = await client.GetFromJsonAsync($"/points/{latitude},{longitude}"); + using var jsonDocument = await client.ReadJsonDocumentAsync($"/points/{latitude},{longitude}"); + var jsonElement = jsonDocument.RootElement; var periods = jsonElement.GetProperty("properties").GetProperty("periods").EnumerateArray(); return string.Join("\n---\n", periods.Select(period => $""" diff --git a/samples/TestServerWithHosting/Program.cs b/samples/TestServerWithHosting/Program.cs index ee009084b..1ab6fc7a2 100644 --- a/samples/TestServerWithHosting/Program.cs +++ b/samples/TestServerWithHosting/Program.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Serilog; +using TestServerWithHosting.Tools; Log.Logger = new LoggerConfiguration() .MinimumLevel.Verbose() // Capture all log levels @@ -19,7 +20,8 @@ builder.Services.AddSerilog(); builder.Services.AddMcpServer() .WithStdioServerTransport() - .WithToolsFromAssembly(); + .WithTools() + .WithTools(); var app = builder.Build(); diff --git a/samples/TestServerWithHosting/TestServerWithHosting.csproj b/samples/TestServerWithHosting/TestServerWithHosting.csproj index 137d753bb..0f3918d95 100644 --- a/samples/TestServerWithHosting/TestServerWithHosting.csproj +++ b/samples/TestServerWithHosting/TestServerWithHosting.csproj @@ -5,6 +5,10 @@ net9.0 enable enable + @@ -13,6 +17,9 @@ + diff --git a/samples/TestServerWithHosting/Tools/EchoTool.cs b/samples/TestServerWithHosting/Tools/EchoTool.cs index 636b4063a..7913b73e4 100644 --- a/samples/TestServerWithHosting/Tools/EchoTool.cs +++ b/samples/TestServerWithHosting/Tools/EchoTool.cs @@ -4,7 +4,7 @@ namespace TestServerWithHosting.Tools; [McpServerToolType] -public static class EchoTool +public sealed class EchoTool { [McpServerTool, Description("Echoes the input back to the client.")] public static string Echo(string message) diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index b1a0353d4..3539b4cdc 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -8,7 +8,7 @@ namespace TestServerWithHosting.Tools; /// This tool uses depenency injection and async method /// [McpServerToolType] -public static class SampleLlmTool +public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( diff --git a/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs new file mode 100644 index 000000000..89822eff1 --- /dev/null +++ b/src/Common/Polyfills/System/Threading/Channels/ChannelExtensions.cs @@ -0,0 +1,17 @@ +using System.Runtime.CompilerServices; + +namespace System.Threading.Channels; + +internal static class ChannelExtensions +{ + public static async IAsyncEnumerable ReadAllAsync(this ChannelReader reader, [EnumeratorCancellation] CancellationToken cancellationToken) + { + while (await reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (reader.TryRead(out var item)) + { + yield return item; + } + } + } +} \ No newline at end of file diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 1d81cc273..f6f35ebd5 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -6,7 +6,7 @@ https://github.com/modelcontextprotocol/csharp-sdk git 0.1.0 - preview.4 + preview.5 ModelContextProtocolOfficial © Anthropic and Contributors. ModelContextProtocol;mcp;ai;llm diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 9ad1848c4..bb96356b3 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -1,5 +1,4 @@ -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; @@ -12,7 +11,7 @@ using System.Collections.Concurrent; using System.Security.Cryptography; -namespace ModelContextProtocol.AspNetCore; +namespace Microsoft.AspNetCore.Builder; /// /// Extension methods for to add MCP endpoints. @@ -40,7 +39,7 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo var requestAborted = context.RequestAborted; response.Headers.ContentType = "text/event-stream"; - response.Headers.CacheControl = "no-cache"; + response.Headers.CacheControl = "no-store"; var sessionId = MakeNewSessionId(); await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}"); @@ -48,15 +47,15 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo { throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created."); } - await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); try { var transportTask = transport.RunAsync(cancellationToken: requestAborted); - runSession ??= RunSession; + await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider); try { + runSession ??= RunSession; await runSession(context, server, requestAborted); } finally @@ -86,11 +85,11 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo if (!_sessions.TryGetValue(sessionId.ToString(), out var transport)) { - await Results.BadRequest($"Session {sessionId} not found.").ExecuteAsync(context); + await Results.BadRequest($"Session ID not found.").ExecuteAsync(context); return; } - var message = await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions, context.RequestAborted); + var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj index 5dd10dbf1..1bc4feb01 100644 --- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj +++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj @@ -1,7 +1,7 @@  - net8.0 + net9.0;net8.0 enable enable true @@ -9,6 +9,7 @@ ModelContextProtocol.AspNetCore ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK. README.md + true diff --git a/src/ModelContextProtocol.AspNetCore/README.md b/src/ModelContextProtocol.AspNetCore/README.md index dd7a59094..457321d09 100644 --- a/src/ModelContextProtocol.AspNetCore/README.md +++ b/src/ModelContextProtocol.AspNetCore/README.md @@ -23,15 +23,15 @@ To get started, install the package from NuGet ``` dotnet new web -dotnet add package ModelContextProtocol.AspNetcore --prerelease +dotnet add package ModelContextProtocol.AspNetCore --prerelease ``` ## Getting Started ```csharp // Program.cs -using ModelContextProtocol; -using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Server; +using System.ComponentModel; var builder = WebApplication.CreateBuilder(args); builder.WebHost.ConfigureKestrel(options => diff --git a/src/ModelContextProtocol/Client/IMcpClient.cs b/src/ModelContextProtocol/Client/IMcpClient.cs index 6761c4ef9..357ce3843 100644 --- a/src/ModelContextProtocol/Client/IMcpClient.cs +++ b/src/ModelContextProtocol/Client/IMcpClient.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Client; /// /// Represents an instance of an MCP client connecting to a specific server. /// -public interface IMcpClient : IAsyncDisposable +public interface IMcpClient : IMcpEndpoint { /// /// Gets the capabilities supported by the server. @@ -24,40 +23,4 @@ public interface IMcpClient : IAsyncDisposable /// It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt. /// string? ServerInstructions { get; } - - /// - /// Adds a handler for server notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - - /// - /// Sends a generic JSON-RPC request to the server. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the server's response. - /// - /// It is recommended to use the capability-specific methods that use this one in their implementation. - /// Use this method for custom requests or those not yet covered explicitly. - /// - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the server. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index b326f3c58..cf27a6b52 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol.Client; /// -internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient +internal sealed class McpClient : McpEndpoint, IMcpClient { private readonly IClientTransport _clientTransport; private readonly McpClientOptions _options; @@ -40,9 +40,14 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); } - SetRequestHandler( + SetRequestHandler( RequestMethods.SamplingCreateMessage, - (request, ct) => samplingHandler(request, ct)); + (request, cancellationToken) => samplingHandler( + request, + request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); } if (options.Capabilities?.Roots is { } rootsCapability) @@ -52,9 +57,11 @@ public McpClient(IClientTransport clientTransport, McpClientOptions options, Mcp throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); } - SetRequestHandler( + SetRequestHandler( RequestMethods.RootsList, - (request, ct) => rootsHandler(request, ct)); + rootsHandler, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); } } @@ -79,9 +86,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) { // Connect transport _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); + StartSession(_sessionTransport); // Perform initialization sequence using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); @@ -90,18 +95,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) try { // Send initialize request - var initializeResponse = await SendRequestAsync( - new JsonRpcRequest + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams { - Method = RequestMethods.Initialize, - Params = new InitializeRequestParams() - { - ProtocolVersion = _options.ProtocolVersion, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo - } + ProtocolVersion = _options.ProtocolVersion, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo }, - initializationCts.Token).ConfigureAwait(false); + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); // Store server information _logger.ServerCapabilitiesReceived(EndpointName, diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 36742b14a..c98f76191 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -1,35 +1,16 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; -using Microsoft.Extensions.AI; -using System.Text.Json; using System.Runtime.CompilerServices; +using System.Text.Json; namespace ModelContextProtocol.Client; -/// -/// Provides extensions for operating on MCP clients. -/// +/// Provides extension methods for interacting with an . public static class McpClientExtensions { - /// - /// Sends a notification to the server with parameters. - /// - /// The client. - /// The notification method name. - /// The parameters to send with the notification. - /// A token to cancel the operation. - public static Task SendNotificationAsync(this IMcpClient client, string method, object? parameters = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(method); - - return client.SendMessageAsync( - new JsonRpcNotification { Method = method, Params = parameters }, - cancellationToken); - } - /// /// Sends a ping request to verify server connectivity. /// @@ -40,34 +21,46 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest(RequestMethods.Ping, null), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.Ping, + parameters: null, + McpJsonUtilities.JsonContext.Default.Object!, + McpJsonUtilities.JsonContext.Default.Object, + cancellationToken: cancellationToken); } /// /// Retrieves a list of available tools from the server. /// /// The client. + /// The serializer options governing tool parameter serialization. /// A token to cancel the operation. /// A list of all available tools. public static async Task> ListToolsAsync( - this IMcpClient client, CancellationToken cancellationToken = default) + this IMcpClient client, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) { Throw.IfNull(client); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + List? tools = null; string? cursor = null; do { - var toolResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var toolResults = await client.SendRequestAsync( + RequestMethods.ToolsList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); tools ??= new List(toolResults.Tools.Count); foreach (var tool in toolResults.Tools) { - tools.Add(new McpClientTool(client, tool)); + tools.Add(new McpClientTool(client, tool, serializerOptions)); } cursor = toolResults.NextCursor; @@ -81,6 +74,7 @@ public static async Task> ListToolsAsync( /// Creates an enumerable for asynchronously enumerating all available tools from the server. /// /// The client. + /// The serializer options governing tool parameter serialization. /// A token to cancel the operation. /// An asynchronous sequence of all available tools. /// @@ -88,20 +82,28 @@ public static async Task> ListToolsAsync( /// will result in requerying the server and yielding the sequence of available tools. /// public static async IAsyncEnumerable EnumerateToolsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) + this IMcpClient client, + JsonSerializerOptions? serializerOptions = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { Throw.IfNull(client); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + string? cursor = null; do { - var toolResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ToolsList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var toolResults = await client.SendRequestAsync( + RequestMethods.ToolsList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); foreach (var tool in toolResults.Tools) { - yield return new McpClientTool(client, tool); + yield return new McpClientTool(client, tool, serializerOptions); } cursor = toolResults.NextCursor; @@ -124,9 +126,12 @@ public static async Task> ListPromptsAsync( string? cursor = null; do { - var promptResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var promptResults = await client.SendRequestAsync( + RequestMethods.PromptsList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); prompts ??= new List(promptResults.Prompts.Count); foreach (var prompt in promptResults.Prompts) @@ -159,9 +164,12 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( string? cursor = null; do { - var promptResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var promptResults = await client.SendRequestAsync( + RequestMethods.PromptsList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); foreach (var prompt in promptResults.Prompts) { @@ -179,17 +187,29 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// The client. /// The name of the prompt to retrieve /// Optional arguments for the prompt + /// The serialization options governing argument serialization. /// A token to cancel the operation. /// A task containing the prompt's content and messages. public static Task GetPromptAsync( - this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default) + this IMcpClient client, + string name, + IReadOnlyDictionary? arguments = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(name); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); - return client.SendRequestAsync( - CreateRequest(RequestMethods.PromptsGet, CreateParametersDictionary(name, arguments)), - cancellationToken); + var parametersTypeInfo = serializerOptions.GetTypeInfo>(); + + return client.SendRequestAsync( + RequestMethods.PromptsGet, + CreateParametersDictionary(name, arguments), + parametersTypeInfo, + McpJsonUtilities.JsonContext.Default.GetPromptResult, + cancellationToken: cancellationToken); } /// @@ -208,9 +228,12 @@ public static async Task> ListResourceTemplatesAsync( string? cursor = null; do { - var templateResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var templateResults = await client.SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); if (templates is null) { @@ -246,9 +269,12 @@ public static async IAsyncEnumerable EnumerateResourceTemplate string? cursor = null; do { - var templateResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesTemplatesList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var templateResults = await client.SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); foreach (var template in templateResults.ResourceTemplates) { @@ -276,9 +302,12 @@ public static async Task> ListResourcesAsync( string? cursor = null; do { - var resourceResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var resourceResults = await client.SendRequestAsync( + RequestMethods.ResourcesList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); if (resources is null) { @@ -314,9 +343,12 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( string? cursor = null; do { - var resourceResults = await client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesList, CreateCursorDictionary(cursor)), - cancellationToken).ConfigureAwait(false); + var resourceResults = await client.SendRequestAsync( + RequestMethods.ResourcesList, + CreateCursorDictionary(cursor)!, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); foreach (var resource in resourceResults.Resources) { @@ -340,9 +372,12 @@ public static Task ReadResourceAsync( Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); - return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesRead, new Dictionary() { ["uri"] = uri }), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.ResourcesRead, + new Dictionary { ["uri"] = uri }, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); } /// @@ -364,13 +399,16 @@ public static Task GetCompletionAsync(this IMcpClient client, Re throw new ArgumentException($"Invalid reference: {validationMessage}", nameof(reference)); } - return client.SendRequestAsync( - CreateRequest(RequestMethods.CompletionComplete, new Dictionary() + return client.SendRequestAsync( + RequestMethods.CompletionComplete, + new Dictionary { ["ref"] = reference, ["argument"] = new Argument { Name = argumentName, Value = argumentValue } - }), - cancellationToken); + }, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.CompleteResult, + cancellationToken: cancellationToken); } /// @@ -384,9 +422,12 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); - return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesSubscribe, new Dictionary() { ["uri"] = uri }), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.ResourcesSubscribe, + new Dictionary { ["uri"] = uri }, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken); } /// @@ -400,9 +441,12 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u Throw.IfNull(client); Throw.IfNullOrWhiteSpace(uri); - return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesUnsubscribe, new Dictionary() { ["uri"] = uri }), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.ResourcesUnsubscribe, + new Dictionary { ["uri"] = uri }, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken); } /// @@ -411,17 +455,29 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// The client. /// The name of the tool to call. /// Optional arguments for the tool. + /// The serialization options governing argument serialization. /// A token to cancel the operation. /// A task containing the tool's response. public static Task CallToolAsync( - this IMcpClient client, string toolName, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default) + this IMcpClient client, + string toolName, + IReadOnlyDictionary? arguments = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNull(toolName); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + var parametersTypeInfo = serializerOptions.GetTypeInfo>(); - return client.SendRequestAsync( - CreateRequest(RequestMethods.ToolsCall, CreateParametersDictionary(toolName, arguments)), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.ToolsCall, + CreateParametersDictionary(toolName, arguments), + parametersTypeInfo, + McpJsonUtilities.JsonContext.Default.CallToolResponse, + cancellationToken: cancellationToken); } /// @@ -531,17 +587,33 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat /// /// The with which to satisfy sampling requests. /// The created handler delegate. - public static Func> CreateSamplingHandler(this IChatClient chatClient) + public static Func, CancellationToken, Task> CreateSamplingHandler( + this IChatClient chatClient) { Throw.IfNull(chatClient); - return async (requestParams, cancellationToken) => + return async (requestParams, progress, cancellationToken) => { Throw.IfNull(requestParams); var (messages, options) = requestParams.ToChatClientArguments(); - var response = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - return response.ToCreateMessageResult(); + var progressToken = requestParams.Meta?.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); }; } @@ -555,18 +627,14 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C { Throw.IfNull(client); - return client.SendRequestAsync( - CreateRequest(RequestMethods.LoggingSetLevel, new Dictionary() { ["level"] = level }), - cancellationToken); + return client.SendRequestAsync( + RequestMethods.LoggingSetLevel, + new Dictionary { ["level"] = level }, + McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken); } - private static JsonRpcRequest CreateRequest(string method, IReadOnlyDictionary? parameters) => - new() - { - Method = method, - Params = parameters - }; - private static Dictionary? CreateCursorDictionary(string? cursor) => cursor != null ? new() { ["cursor"] = cursor } : null; @@ -585,36 +653,4 @@ private static JsonRpcRequest CreateRequest(string method, IReadOnlyDictionaryProvides an AI function that calls a tool through . - private sealed class McpAIFunction(IMcpClient client, Tool tool) : AIFunction - { - /// - public override string Name => tool.Name; - - /// - public override string Description => tool.Description ?? string.Empty; - - /// - public override JsonElement JsonSchema => tool.InputSchema; - - /// - public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions; - - /// - protected async override Task InvokeCoreAsync( - IEnumerable> arguments, CancellationToken cancellationToken) - { - IReadOnlyDictionary argDict = - arguments as IReadOnlyDictionary ?? -#if NET - arguments.ToDictionary(); -#else - arguments.ToDictionary(kv => kv.Key, kv => kv.Value); -#endif - - CallToolResponse result = await client.CallToolAsync(tool.Name, argDict, cancellationToken).ConfigureAwait(false); - return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse); - } - } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/McpClientPrompt.cs b/src/ModelContextProtocol/Client/McpClientPrompt.cs index 8deed8eb0..71d6a4e67 100644 --- a/src/ModelContextProtocol/Client/McpClientPrompt.cs +++ b/src/ModelContextProtocol/Client/McpClientPrompt.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol.Types; +using System.Text.Json; namespace ModelContextProtocol.Client; @@ -20,17 +21,19 @@ internal McpClientPrompt(IMcpClient client, Prompt prompt) /// Retrieves a specific prompt with optional arguments. /// /// Optional arguments for the prompt + /// The serialization options governing argument serialization. /// A token to cancel the operation. /// A task containing the prompt's content and messages. public async ValueTask GetAsync( IEnumerable>? arguments = null, + JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) { IReadOnlyDictionary? argDict = arguments as IReadOnlyDictionary ?? arguments?.ToDictionary(); - return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, cancellationToken).ConfigureAwait(false); + return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, serializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false); } /// Gets the name of the prompt. diff --git a/src/ModelContextProtocol/Client/McpClientTool.cs b/src/ModelContextProtocol/Client/McpClientTool.cs index 10ccd81a6..58cb02d53 100644 --- a/src/ModelContextProtocol/Client/McpClientTool.cs +++ b/src/ModelContextProtocol/Client/McpClientTool.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils.Json; +using ModelContextProtocol.Utils; using Microsoft.Extensions.AI; using System.Text.Json; @@ -9,27 +10,57 @@ namespace ModelContextProtocol.Client; public sealed class McpClientTool : AIFunction { private readonly IMcpClient _client; + private readonly string _name; + private readonly string _description; - internal McpClientTool(IMcpClient client, Tool tool) + internal McpClientTool(IMcpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, string? description = null) { _client = client; ProtocolTool = tool; + JsonSerializerOptions = serializerOptions; + _name = name ?? tool.Name; + _description = description ?? tool.Description ?? string.Empty; + } + + /// + /// Creates a new instance of the tool with the specified name. + /// This is useful for optimizing the tool name for specific models or for prefixing the tool name with a (usually server-derived) namespace to avoid conflicts. + /// The server will still be called with the original tool name, so no mapping is required. + /// + /// The model-facing name to give the tool. + /// Copy of this McpClientTool with the provided name + public McpClientTool WithName(string name) + { + return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, name, _description); + } + + /// + /// Creates a new instance of the tool with the specified description. + /// This can be used to provide modified or additional (e.g. examples) context to the model about the tool. + /// This will in general require a hard-coded mapping in the client. + /// It is not recommended to use this without running evaluations to ensure the model actually benefits from the custom description. + /// + /// The description to give the tool. + /// Copy of this McpClientTool with the provided description + public McpClientTool WithDescription(string description) + { + return new McpClientTool(_client, ProtocolTool, JsonSerializerOptions, _name, description); } /// Gets the protocol type for this instance. public Tool ProtocolTool { get; } /// - public override string Name => ProtocolTool.Name; + public override string Name => _name; /// - public override string Description => ProtocolTool.Description ?? string.Empty; + public override string Description => _description; /// public override JsonElement JsonSchema => ProtocolTool.InputSchema; /// - public override JsonSerializerOptions JsonSerializerOptions => McpJsonUtilities.DefaultOptions; + public override JsonSerializerOptions JsonSerializerOptions { get; } /// protected async override Task InvokeCoreAsync( @@ -39,7 +70,7 @@ internal McpClientTool(IMcpClient client, Tool tool) arguments as IReadOnlyDictionary ?? arguments.ToDictionary(); - CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, argDict, cancellationToken).ConfigureAwait(false); + CallToolResponse result = await _client.CallToolAsync(ProtocolTool.Name, argDict, JsonSerializerOptions, cancellationToken: cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToElement(result, McpJsonUtilities.JsonContext.Default.CallToolResponse); } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index 3fbc4fea7..a59de8ce8 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -176,7 +176,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, para } /// - /// Adds types marked with the attribute from the given assembly as prompts to the server. + /// Adds types marked with the attribute from the given assembly as prompts to the server. /// /// The builder instance. /// The assembly to load the types from. Null to get the current assembly @@ -190,7 +190,7 @@ public static IMcpServerBuilder WithPromptsFromAssembly(this IMcpServerBuilder b return builder.WithPrompts( from t in promptAssembly.GetTypes() - where t.GetCustomAttribute() is not null + where t.GetCustomAttribute() is not null select t); } #endregion diff --git a/src/ModelContextProtocol/Configuration/McpServerConfig.cs b/src/ModelContextProtocol/Configuration/McpServerConfig.cs index c8d0a26eb..27cd39e41 100644 --- a/src/ModelContextProtocol/Configuration/McpServerConfig.cs +++ b/src/ModelContextProtocol/Configuration/McpServerConfig.cs @@ -27,11 +27,6 @@ public record McpServerConfig /// public string? Location { get; set; } - /// - /// Arguments (if any) to pass to the executable. - /// - public string[]? Arguments { get; init; } - /// /// Additional transport-specific configuration. /// diff --git a/src/ModelContextProtocol/Diagnostics.cs b/src/ModelContextProtocol/Diagnostics.cs new file mode 100644 index 000000000..5b4e31f4d --- /dev/null +++ b/src/ModelContextProtocol/Diagnostics.cs @@ -0,0 +1,37 @@ +using System.Diagnostics; +using System.Diagnostics.Metrics; + +namespace ModelContextProtocol; + +internal static class Diagnostics +{ + internal static ActivitySource ActivitySource { get; } = new("Experimental.ModelContextProtocol"); + + internal static Meter Meter { get; } = new("Experimental.ModelContextProtocol"); + + internal static Histogram CreateDurationHistogram(string name, string description, bool longBuckets) => + Diagnostics.Meter.CreateHistogram(name, "s", description +#if NET9_0_OR_GREATER + , advice: longBuckets ? LongSecondsBucketBoundaries : ShortSecondsBucketBoundaries +#endif + ); + +#if NET9_0_OR_GREATER + /// + /// Follows boundaries from http.server.request.duration/http.client.request.duration + /// + private static InstrumentAdvice ShortSecondsBucketBoundaries { get; } = new() + { + HistogramBucketBoundaries = [0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10], + }; + + /// + /// Not based on a standard. Larger bucket sizes for longer lasting operations, e.g. HTTP connection duration. + /// See https://github.com/open-telemetry/semantic-conventions/issues/336 + /// + private static InstrumentAdvice LongSecondsBucketBoundaries { get; } = new() + { + HistogramBucketBoundaries = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10, 30, 60, 120, 300], + }; +#endif +} diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs new file mode 100644 index 000000000..6643e02a7 --- /dev/null +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Protocol.Messages; + +namespace ModelContextProtocol; + +/// Represents a client or server MCP endpoint. +public interface IMcpEndpoint : IAsyncDisposable +{ + /// Sends a JSON-RPC request to the connected endpoint. + /// The JSON-RPC request to send. + /// A token to cancel the operation. + /// A task containing the client's response. + Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); + + /// Sends a message to the connected endpoint. + /// The message. + /// A token to cancel the operation. + Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + + /// + /// Adds a handler for server notifications of a specific method. + /// + /// The notification method to handle. + /// The async handler function to process notifications. + /// + /// + /// Each method may have multiple handlers. Adding a handler for a method that already has one + /// will not replace the existing handler. + /// + /// + /// provides constants for common notification methods. + /// + /// + void AddNotificationHandler(string method, Func handler); +} diff --git a/src/ModelContextProtocol/Logging/Log.cs b/src/ModelContextProtocol/Logging/Log.cs index d22c5d664..9b5d44bdf 100644 --- a/src/ModelContextProtocol/Logging/Log.cs +++ b/src/ModelContextProtocol/Logging/Log.cs @@ -77,9 +77,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Information, Message = "Request response received for {endpointName} with method {method}")] internal static partial void RequestResponseReceived(this ILogger logger, string endpointName, string method); - [LoggerMessage(Level = LogLevel.Error, Message = "Request response type conversion error for {endpointName} with method {method}: expected {expectedType}")] - internal static partial void RequestResponseTypeConversionError(this ILogger logger, string endpointName, string method, Type expectedType); - [LoggerMessage(Level = LogLevel.Error, Message = "Request invalid response type for {endpointName} with method {method}")] internal static partial void RequestInvalidResponseType(this ILogger logger, string endpointName, string method); @@ -98,8 +95,8 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Information, Message = "Creating process for transport for {endpointName} with command {command}, arguments {arguments}, environment {environment}, working directory {workingDirectory}, shutdown timeout {shutdownTimeout}")] internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory, string shutdownTimeout); - [LoggerMessage(Level = LogLevel.Error, Message = "Transport for {endpointName} error: {data}")] - internal static partial void TransportError(this ILogger logger, string endpointName, string data); + [LoggerMessage(Level = LogLevel.Information, Message = "Transport for {endpointName} received stderr log: {data}")] + internal static partial void ReadStderr(this ILogger logger, string endpointName, string data); [LoggerMessage(Level = LogLevel.Information, Message = "Transport process start failed for {endpointName}")] internal static partial void TransportProcessStartFailed(this ILogger logger, string endpointName); diff --git a/src/ModelContextProtocol/McpEndpointExtensions.cs b/src/ModelContextProtocol/McpEndpointExtensions.cs new file mode 100644 index 000000000..d2a6a952e --- /dev/null +++ b/src/ModelContextProtocol/McpEndpointExtensions.cs @@ -0,0 +1,168 @@ +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol; + +/// Provides extension methods for interacting with an . +public static class McpEndpointExtensions +{ + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The MCP client or server instance. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The request id for the request. + /// The options governing request serialization. + /// A token to cancel the operation. + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + public static Task SendRequestAsync( + this IMcpEndpoint endpoint, + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + RequestId? requestId = null, + CancellationToken cancellationToken = default) + where TResult : notnull + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); + JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); + return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); + } + + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The MCP client or server instance. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The type information for request parameter deserialization. + /// The request id for the request. + /// A token to cancel the operation. + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + internal static async Task SendRequestAsync( + this IMcpEndpoint endpoint, + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + JsonTypeInfo resultTypeInfo, + RequestId? requestId = null, + CancellationToken cancellationToken = default) + where TResult : notnull + { + Throw.IfNull(endpoint); + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + Throw.IfNull(resultTypeInfo); + + JsonRpcRequest jsonRpcRequest = new() + { + Method = method, + Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), + }; + + if (requestId is { } id) + { + jsonRpcRequest.Id = id; + } + + JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); + return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); + } + + /// + /// Sends a notification to the server with parameters. + /// + /// The client. + /// The notification method name. + /// A token to cancel the operation. + public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default) + { + Throw.IfNull(client); + Throw.IfNullOrWhiteSpace(method); + return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); + } + + /// + /// Sends a notification to the server with parameters. + /// + /// The MCP client or server instance. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The options governing request serialization. + /// A token to cancel the operation. + public static Task SendNotificationAsync( + this IMcpEndpoint endpoint, + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); + return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken); + } + + /// + /// Sends a notification to the server with parameters. + /// + /// The MCP client or server instance. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// A token to cancel the operation. + internal static Task SendNotificationAsync( + this IMcpEndpoint endpoint, + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + CancellationToken cancellationToken = default) + { + Throw.IfNull(endpoint); + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + + JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); + return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); + } + + /// Notifies the connected endpoint of progress. + /// The endpoint issuing the notification. + /// The identifying the operation. + /// The progress update to send. + /// A token to cancel the operation. + /// A task representing the completion of the operation. + /// is . + public static Task NotifyProgressAsync( + this IMcpEndpoint endpoint, + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) + { + Throw.IfNull(endpoint); + + return endpoint.SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.ProgressNotification, + Params = JsonSerializer.SerializeToNode(new ProgressNotification + { + ProgressToken = progressToken, + Progress = progress, + }, McpJsonUtilities.JsonContext.Default.ProgressNotification), + }, cancellationToken); + } +} diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index 6860381de..c120269bd 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -1,7 +1,7 @@  - net8.0;netstandard2.0 + net9.0;net8.0;netstandard2.0 true true ModelContextProtocol @@ -12,15 +12,8 @@ true - - - - - - - - + @@ -28,6 +21,24 @@ + + + + + + + + + + + + + + + diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs index 359773712..302a46379 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcNotification.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Messages; @@ -23,5 +24,5 @@ public record JsonRpcNotification : IJsonRpcMessage /// Optional parameters for the notification. /// [JsonPropertyName("params")] - public object? Params { get; init; } + public JsonNode? Params { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs index 9de16b8f2..8ae4a56de 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcRequest.cs @@ -1,4 +1,5 @@ -using System.Text.Json.Serialization; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Messages; @@ -29,5 +30,5 @@ public record JsonRpcRequest : IJsonRpcMessageWithId /// Optional parameters for the method. /// [JsonPropertyName("params")] - public object? Params { get; init; } + public JsonNode? Params { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs index a5a1238a3..9ced163f9 100644 --- a/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs +++ b/src/ModelContextProtocol/Protocol/Messages/JsonRpcResponse.cs @@ -1,4 +1,4 @@ - +using System.Text.Json.Nodes; using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Messages; @@ -23,5 +23,5 @@ public record JsonRpcResponse : IJsonRpcMessageWithId /// The result of the method invocation. /// [JsonPropertyName("result")] - public required object? Result { get; init; } + public required JsonNode? Result { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs index 6183eb92e..b8a481086 100644 --- a/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs +++ b/src/ModelContextProtocol/Protocol/Messages/ProgressToken.cs @@ -12,15 +12,15 @@ namespace ModelContextProtocol.Protocol.Messages; [JsonConverter(typeof(Converter))] public readonly struct ProgressToken : IEquatable { - /// The id, either a string or a boxed long or null. - private readonly object? _id; + /// The token, either a string or a boxed long or null. + private readonly object? _token; /// Initializes a new instance of the with a specified value. /// The required ID value. public ProgressToken(string value) { Throw.IfNull(value); - _id = value; + _token = value; } /// Initializes a new instance of the with a specified value. @@ -28,28 +28,29 @@ public ProgressToken(string value) public ProgressToken(long value) { // Box the long. Progress tokens are almost always strings in practice, so this should be rare. - _id = value; + _token = value; } - /// Gets whether the identifier is uninitialized. - public bool IsDefault => _id is null; + /// Gets the underlying object for this token. + /// This will either be a , a boxed , or . + public object? Token => _token; /// public override string? ToString() => - _id is string stringValue ? $"\"{stringValue}\"" : - _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : + _token is string stringValue ? $"{stringValue}" : + _token is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : null; /// /// Compares this ProgressToken to another ProgressToken. /// - public bool Equals(ProgressToken other) => Equals(_id, other._id); + public bool Equals(ProgressToken other) => Equals(_token, other._token); /// public override bool Equals(object? obj) => obj is ProgressToken other && Equals(other); /// - public override int GetHashCode() => _id?.GetHashCode() ?? 0; + public override int GetHashCode() => _token?.GetHashCode() ?? 0; /// /// Compares two ProgressTokens for equality. @@ -83,7 +84,7 @@ public override void Write(Utf8JsonWriter writer, ProgressToken value, JsonSeria { Throw.IfNull(writer); - switch (value._id) + switch (value._token) { case string str: writer.WriteStringValue(str); diff --git a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs index e6fc74418..550428ff9 100644 --- a/src/ModelContextProtocol/Protocol/Messages/RequestId.cs +++ b/src/ModelContextProtocol/Protocol/Messages/RequestId.cs @@ -31,12 +31,13 @@ public RequestId(long value) _id = value; } - /// Gets whether the identifier is uninitialized. - public bool IsDefault => _id is null; + /// Gets the underlying object for this id. + /// This will either be a , a boxed , or . + public object? Id => _id; /// public override string ToString() => - _id is string stringValue ? $"\"{stringValue}\"" : + _id is string stringValue ? stringValue : _id is long longValue ? longValue.ToString(CultureInfo.InvariantCulture) : string.Empty; diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 2b4acac99..49b2fe401 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -24,7 +24,6 @@ internal sealed class SseClientSessionTransport : TransportBase private Task? _receiveTask; private readonly ILogger _logger; private readonly McpServerConfig _serverConfig; - private readonly JsonSerializerOptions _jsonOptions; private readonly TaskCompletionSource _connectionEstablished; private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})"; @@ -50,7 +49,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Mcp _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _jsonOptions = McpJsonUtilities.DefaultOptions; _connectionEstablished = new TaskCompletionSource(); } @@ -94,7 +92,7 @@ public override async Task SendMessageAsync( throw new InvalidOperationException("Transport not connected"); using var content = new StringContent( - JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()), + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage), Encoding.UTF8, "application/json" ); @@ -127,7 +125,7 @@ public override async Task SendMessageAsync( } else { - JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, _jsonOptions.GetTypeInfo()) ?? + JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? throw new McpTransportException("Failed to initialize client"); _logger.TransportReceivedMessageParsed(EndpointName, messageId); @@ -259,7 +257,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation try { - var message = JsonSerializer.Deserialize(data, _jsonOptions.GetTypeInfo()); + var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage); if (message == null) { _logger.TransportMessageParseUnexpectedType(EndpointName, data); diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index eafd1f614..d4e39c8a4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -32,19 +32,6 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. public Task RunAsync(CancellationToken cancellationToken) { - void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer) - { - if (item.EventType == "endpoint") - { - writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); - return; - } - - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - } - - IsConnected = true; - // The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type, // so we fib and special-case the "endpoint" event type in the formatter. if (!_outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint"))) @@ -52,10 +39,23 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter item, IBufferWriter writer) + { + if (item.EventType == "endpoint") + { + writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); + return; + } + + JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage!); + } + /// public ChannelReader MessageReader => _incomingChannel.Reader; @@ -76,7 +76,8 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } - await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message), cancellationToken); + // Emit redundant "event: message" lines for better compatibility with other SDKs. + await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken); } /// diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs new file mode 100644 index 000000000..af304c897 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -0,0 +1,40 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using System.Diagnostics; + +namespace ModelContextProtocol.Protocol.Transport; + +/// Provides the client side of a stdio-based session transport. +internal sealed class StdioClientSessionTransport : StreamClientSessionTransport +{ + private readonly StdioClientTransportOptions _options; + private readonly Process _process; + + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory) + : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) + { + _process = process; + _options = options; + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (_process.HasExited) + { + Logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + + /// + protected override ValueTask CleanupAsync(CancellationToken cancellationToken) + { + StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName); + + return base.CleanupAsync(cancellationToken); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs deleted file mode 100644 index 35c957e52..000000000 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs +++ /dev/null @@ -1,326 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Diagnostics; -using System.Text; -using System.Text.Json; - -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Implements the MCP transport protocol over standard input/output streams. -/// -internal sealed class StdioClientStreamTransport : TransportBase -{ - private readonly StdioClientTransportOptions _options; - private readonly McpServerConfig _serverConfig; - private readonly ILogger _logger; - private readonly JsonSerializerOptions _jsonOptions; - private readonly DataReceivedEventHandler _logProcessErrors; - private readonly SemaphoreSlim _sendLock = new(1, 1); - private Process? _process; - private Task? _readTask; - private CancellationTokenSource? _shutdownCts; - private bool _processStarted; - - private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; - - /// - /// Initializes a new instance of the StdioTransport class. - /// - /// Configuration options for the transport. - /// The server configuration for the transport. - /// A logger factory for creating loggers. - public StdioClientStreamTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) - { - Throw.IfNull(options); - Throw.IfNull(serverConfig); - - _options = options; - _serverConfig = serverConfig; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _logProcessErrors = (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); - _jsonOptions = McpJsonUtilities.DefaultOptions; - } - - /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - if (IsConnected) - { - _logger.TransportAlreadyConnected(EndpointName); - throw new McpTransportException("Transport is already connected"); - } - - try - { - _logger.TransportConnecting(EndpointName); - - _shutdownCts = new CancellationTokenSource(); - - UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); - - var startInfo = new ProcessStartInfo - { - FileName = _options.Command, - RedirectStandardInput = true, - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true, - WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, - StandardOutputEncoding = noBomUTF8, - StandardErrorEncoding = noBomUTF8, -#if NET - StandardInputEncoding = noBomUTF8, -#endif - }; - - if (!string.IsNullOrWhiteSpace(_options.Arguments)) - { - startInfo.Arguments = _options.Arguments; - } - - if (_options.EnvironmentVariables != null) - { - foreach (var entry in _options.EnvironmentVariables) - { - startInfo.Environment[entry.Key] = entry.Value; - } - } - - _logger.CreateProcessForTransport(EndpointName, _options.Command, - startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), - startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); - - _process = new Process { StartInfo = startInfo }; - - // Set up error logging - _process.ErrorDataReceived += _logProcessErrors; - - // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, - // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but - // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks - // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, - // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start - // call, to ensure it picks up the correct encoding. -#if NET - _processStarted = _process.Start(); -#else - Encoding originalInputEncoding = Console.InputEncoding; - try - { - Console.InputEncoding = noBomUTF8; - _processStarted = _process.Start(); - } - finally - { - Console.InputEncoding = originalInputEncoding; - } -#endif - - if (!_processStarted) - { - _logger.TransportProcessStartFailed(EndpointName); - throw new McpTransportException("Failed to start MCP server process"); - } - - _logger.TransportProcessStarted(EndpointName, _process.Id); - - _process.BeginErrorReadLine(); - - // Start reading messages in the background - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); - _logger.TransportReadingMessages(EndpointName); - - SetConnected(true); - } - catch (Exception ex) - { - _logger.TransportConnectFailed(EndpointName, ex); - await CleanupAsync(cancellationToken).ConfigureAwait(false); - throw new McpTransportException("Failed to connect transport", ex); - } - } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); - - if (!IsConnected || _process?.HasExited == true) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); - _logger.TransportSendingMessage(EndpointName, id, json); - _logger.TransportMessageBytesUtf8(EndpointName, json); - - // Write the message followed by a newline using our UTF-8 writer - await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false); - await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false); - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } - } - - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) - { - try - { - _logger.TransportEnteringReadMessagesLoop(EndpointName); - - while (!cancellationToken.IsCancellationRequested && !_process!.HasExited) - { - _logger.TransportWaitingForMessage(EndpointName); - var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false); - if (line == null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - - if (string.IsNullOrWhiteSpace(line)) - { - continue; - } - - _logger.TransportReceivedMessage(EndpointName, line); - _logger.TransportMessageBytesUtf8(EndpointName, line); - - await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); - } - _logger.TransportExitingReadMessagesLoop(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(EndpointName, ex); - } - finally - { - await CleanupAsync(cancellationToken).ConfigureAwait(false); - } - } - - private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) - { - try - { - line=line.Trim();//Fixes an error when the service prefixes nonprintable characters - var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); - if (message != null) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, line); - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, line, ex); - } - } - - private async Task CleanupAsync(CancellationToken cancellationToken) - { - _logger.TransportCleaningUp(EndpointName); - - if (_process is Process process && _processStarted && !process.HasExited) - { - try - { - // Wait for the process to exit - _logger.TransportWaitingForShutdown(EndpointName); - - // Kill the while process tree because the process may spawn child processes - // and Node.js does not kill its children when it exits properly - process.KillTree(_options.ShutdownTimeout); - } - catch (Exception ex) - { - _logger.TransportShutdownFailed(EndpointName, ex); - } - finally - { - process.ErrorDataReceived -= _logProcessErrors; - process.Dispose(); - _process = null; - } - } - - if (_shutdownCts is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - _shutdownCts = null; - } - - if (_readTask is Task readTask) - { - try - { - _logger.TransportWaitingForReadTask(EndpointName); - await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); - } - catch (TimeoutException) - { - _logger.TransportCleanupReadTaskTimeout(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportCleanupReadTaskCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(EndpointName, ex); - } - finally - { - _logger.TransportReadTaskCleanedUp(EndpointName); - _readTask = null; - } - } - - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index d2b51b950..774677f13 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -1,10 +1,16 @@ using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; using ModelContextProtocol.Utils; +using System.Diagnostics; +using System.Text; + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously namespace ModelContextProtocol.Protocol.Transport; /// -/// Implements the MCP transport protocol over standard input/output streams. +/// Provides a client MCP transport implemented via "stdio" (standard input/output). /// public sealed class StdioClientTransport : IClientTransport { @@ -13,7 +19,7 @@ public sealed class StdioClientTransport : IClientTransport private readonly ILoggerFactory? _loggerFactory; /// - /// Initializes a new instance of the StdioTransport class. + /// Initializes a new instance of the class. /// /// Configuration options for the transport. /// The server configuration for the transport. @@ -31,17 +37,121 @@ public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - var streamTransport = new StdioClientStreamTransport(_options, _serverConfig, _loggerFactory); + string endpointName = $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; + + Process? process = null; + bool processStarted = false; + ILogger logger = (ILogger?)_loggerFactory?.CreateLogger() ?? NullLogger.Instance; try { - await streamTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - return streamTransport; + logger.TransportConnecting(endpointName); + + UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); + + ProcessStartInfo startInfo = new() + { + FileName = _options.Command, + RedirectStandardInput = true, + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true, + WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, + StandardOutputEncoding = noBomUTF8, + StandardErrorEncoding = noBomUTF8, +#if NET + StandardInputEncoding = noBomUTF8, +#endif + }; + + if (!string.IsNullOrWhiteSpace(_options.Arguments)) + { + startInfo.Arguments = _options.Arguments; + } + + if (_options.EnvironmentVariables != null) + { + foreach (var entry in _options.EnvironmentVariables) + { + startInfo.Environment[entry.Key] = entry.Value; + } + } + + logger.CreateProcessForTransport(endpointName, _options.Command, + startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), + startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); + + process = new() { StartInfo = startInfo }; + + // Set up error logging + process.ErrorDataReceived += (sender, args) => logger.ReadStderr(endpointName, args.Data ?? "(no data)"); + + // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, + // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but + // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks + // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, + // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start + // call, to ensure it picks up the correct encoding. +#if NET + processStarted = process.Start(); +#else + Encoding originalInputEncoding = Console.InputEncoding; + try + { + Console.InputEncoding = noBomUTF8; + processStarted = process.Start(); + } + finally + { + Console.InputEncoding = originalInputEncoding; + } +#endif + + if (!processStarted) + { + logger.TransportProcessStartFailed(endpointName); + throw new McpTransportException("Failed to start MCP server process"); + } + + logger.TransportProcessStarted(endpointName, process.Id); + + process.BeginErrorReadLine(); + + return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory); } - catch + catch (Exception ex) + { + logger.TransportConnectFailed(endpointName, ex); + DisposeProcess(process, processStarted, logger, _options.ShutdownTimeout, endpointName); + throw new McpTransportException("Failed to connect transport", ex); + } + } + + internal static void DisposeProcess( + Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName) + { + if (process is not null) { - await streamTransport.DisposeAsync().ConfigureAwait(false); - throw; + try + { + if (processStarted && !process.HasExited) + { + // Wait for the process to exit. + // Kill the while process tree because the process may spawn child processes + // and Node.js does not kill its children when it exits properly. + logger.TransportWaitingForShutdown(endpointName); + process.KillTree(shutdownTimeout); + } + } + catch (Exception ex) + { + logger.TransportShutdownFailed(endpointName, ex); + } + finally + { + process.Dispose(); + } } } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 7779edc99..58077dbb2 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -1,38 +1,15 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Text; -using System.Text.Json; namespace ModelContextProtocol.Protocol.Transport; /// -/// Provides an implementation of the MCP transport protocol over standard input/output streams. +/// Provides a server MCP transport implemented via "stdio" (standard input/output). /// -public sealed class StdioServerTransport : TransportBase, ITransport +public sealed class StdioServerTransport : StreamServerTransport { - private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); - - private readonly string _serverName; - private readonly ILogger _logger; - - private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; - private readonly TextReader _stdInReader; - private readonly Stream _stdOutStream; - - private readonly SemaphoreSlim _sendLock = new(1, 1); - private readonly CancellationTokenSource _shutdownCts = new(); - - private readonly Task _readLoopCompleted; - private int _disposed = 0; - - private string EndpointName => $"Server (stdio) ({_serverName})"; - /// /// Initializes a new instance of the class, using /// and for input and output streams. @@ -69,14 +46,6 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg { } - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - Throw.IfNull(serverOptions.ServerInfo.Name); - return serverOptions.ServerInfo.Name; - } - /// /// Initializes a new instance of the class, using /// and for input and output streams. @@ -90,186 +59,20 @@ private static string GetServerName(McpServerOptions serverOptions) /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory) - : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory) - { - } - - /// - /// Initializes a new instance of the class with explicit input/output streams. - /// - /// The name of the server. - /// The input to use as standard input. If , will be used. - /// The output to use as standard output. If , will be used. - /// Optional logger factory used for logging employed by the transport. - /// is . - /// - /// - /// This constructor is useful for testing scenarios where you want to redirect input/output. - /// - /// - public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) - { - Throw.IfNull(serverName); - - _serverName = serverName; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8); - _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput()); - - SetConnected(true); - _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); - } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null) + : base(Console.OpenStandardInput(), + new BufferedStream(Console.OpenStandardOutput()), + serverName ?? throw new ArgumentNullException(nameof(serverName)), + loggerFactory) { - using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); - - if (!IsConnected) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - _logger.TransportSendingMessage(EndpointName, id); - - await JsonSerializer.SerializeAsync(_stdOutStream, message, _jsonOptions.GetTypeInfo(), cancellationToken).ConfigureAwait(false); - await _stdOutStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); - await _stdOutStream.FlushAsync(cancellationToken).ConfigureAwait(false);; - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } } - private async Task ReadMessagesAsync() - { - CancellationToken shutdownToken = _shutdownCts.Token; - try - { - _logger.TransportEnteringReadMessagesLoop(EndpointName); - - while (!shutdownToken.IsCancellationRequested) - { - _logger.TransportWaitingForMessage(EndpointName); - - var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(line)) - { - if (line is null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - - continue; - } - - _logger.TransportReceivedMessage(EndpointName, line); - _logger.TransportMessageBytesUtf8(EndpointName, line); - - try - { - if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - - await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, line); - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, line, ex); - // Continue reading even if we fail to parse a message - } - } - - _logger.TransportExitingReadMessagesLoop(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportReadMessagesCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(EndpointName, ex); - } - finally - { - SetConnected(false); - } - } - - /// - public override async ValueTask DisposeAsync() + private static string GetServerName(McpServerOptions serverOptions) { - if (Interlocked.Exchange(ref _disposed, 1) != 0) - { - return; - } - - try - { - _logger.TransportCleaningUp(EndpointName); - - // Signal to the stdin reading loop to stop. - await _shutdownCts.CancelAsync().ConfigureAwait(false); - _shutdownCts.Dispose(); - - // Dispose of stdin/out. Cancellation may not be able to wake up operations - // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. - _stdInReader?.Dispose(); - _stdOutStream?.Dispose(); + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + Throw.IfNull(serverOptions.ServerInfo.Name); - // Make sure the work has quiesced. - try - { - _logger.TransportWaitingForReadTask(EndpointName); - await _readLoopCompleted.ConfigureAwait(false); - _logger.TransportReadTaskCleanedUp(EndpointName); - } - catch (TimeoutException) - { - _logger.TransportCleanupReadTaskTimeout(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportCleanupReadTaskCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(EndpointName, ex); - } - } - finally - { - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } + return serverOptions.ServerInfo.Name; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs new file mode 100644 index 000000000..589e9078b --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -0,0 +1,195 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// Provides the client side of a stream-based session transport. +internal class StreamClientSessionTransport : TransportBase +{ + private readonly TextReader _serverOutput; + private readonly TextWriter _serverInput; + private readonly SemaphoreSlim _sendLock = new(1, 1); + private CancellationTokenSource? _shutdownCts = new(); + private Task? _readTask; + + /// + /// Initializes a new instance of the class. + /// + public StreamClientSessionTransport( + TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory) + : base(loggerFactory) + { + Logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _serverOutput = serverOutput; + _serverInput = serverInput; + EndpointName = endpointName; + + // Start reading messages in the background. We use the rarer pattern of new Task + Start + // in order to ensure that the body of the task will always see _readTask initialized. + // It is then able to reliably null it out on completion. + Logger.TransportReadingMessages(endpointName); + var readTask = new Task( + thisRef => ((StreamClientSessionTransport)thisRef!).ReadMessagesAsync(_shutdownCts.Token), + this, + TaskCreationOptions.DenyChildAttach); + _readTask = readTask.Unwrap(); + readTask.Start(); + + SetConnected(true); + } + + protected ILogger Logger { get; private set; } + + protected string EndpointName { get; } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + Logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + try + { + Logger.TransportSendingMessage(EndpointName, id, json); + Logger.TransportMessageBytesUtf8(EndpointName, json); + + // Write the message followed by a newline using our UTF-8 writer + await _serverInput.WriteLineAsync(json).ConfigureAwait(false); + await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false); + + Logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + Logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override ValueTask DisposeAsync() => + CleanupAsync(CancellationToken.None); + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + Logger.TransportEnteringReadMessagesLoop(EndpointName); + + while (!cancellationToken.IsCancellationRequested) + { + Logger.TransportWaitingForMessage(EndpointName); + if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) + { + Logger.TransportEndOfStream(EndpointName); + break; + } + + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + Logger.TransportReceivedMessage(EndpointName, line); + Logger.TransportMessageBytesUtf8(EndpointName, line); + + await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); + } + Logger.TransportExitingReadMessagesLoop(EndpointName); + } + catch (OperationCanceledException) + { + Logger.TransportReadMessagesCancelled(EndpointName); + } + catch (Exception ex) + { + Logger.TransportReadMessagesFailed(EndpointName, ex); + } + finally + { + _readTask = null; + await CleanupAsync(cancellationToken).ConfigureAwait(false); + } + } + + private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) + { + try + { + var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + if (message != null) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + + Logger.TransportReceivedMessageParsed(EndpointName, messageId); + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + Logger.TransportMessageWritten(EndpointName, messageId); + } + else + { + Logger.TransportMessageParseUnexpectedType(EndpointName, line); + } + } + catch (JsonException ex) + { + Logger.TransportMessageParseFailed(EndpointName, line, ex); + } + } + + protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken) + { + Logger.TransportCleaningUp(EndpointName); + + if (Interlocked.Exchange(ref _shutdownCts, null) is { } shutdownCts) + { + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); + } + + if (Interlocked.Exchange(ref _readTask, null) is Task readTask) + { + try + { + Logger.TransportWaitingForReadTask(EndpointName); + await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + Logger.TransportReadTaskCleanedUp(EndpointName); + } + catch (TimeoutException) + { + Logger.TransportCleanupReadTaskTimeout(EndpointName); + } + catch (OperationCanceledException) + { + Logger.TransportCleanupReadTaskCancelled(EndpointName); + } + catch (Exception ex) + { + Logger.TransportCleanupReadTaskFailed(EndpointName, ex); + } + } + + SetConnected(false); + Logger.TransportCleanedUp(EndpointName); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs new file mode 100644 index 000000000..80bd61df5 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs @@ -0,0 +1,47 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides a client MCP transport implemented around a pair of input/output streams. +/// +public sealed class StreamClientTransport : IClientTransport +{ + private readonly Stream _serverInput; + private readonly Stream _serverOutput; + private readonly ILoggerFactory? _loggerFactory; + + /// + /// Initializes a new instance of the class. + /// + /// + /// The stream representing the connected server's input. + /// Writes to this stream will be sent to the server. + /// + /// + /// The stream representing the connected server's output. + /// Reads from this stream will receive messages from the server. + /// + /// A logger factory for creating loggers. + public StreamClientTransport( + Stream serverInput, Stream serverOutput, ILoggerFactory? loggerFactory = null) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + + _serverInput = serverInput; + _serverOutput = serverOutput; + _loggerFactory = loggerFactory; + } + + /// + public Task ConnectAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult(new StreamClientSessionTransport( + new StreamWriter(_serverInput), + new StreamReader(_serverOutput), + "Client (stream)", + _loggerFactory)); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs new file mode 100644 index 000000000..ebdf36350 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -0,0 +1,209 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.IO.Pipelines; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides a server MCP transport implemented around a pair of input/output streams. +/// +public class StreamServerTransport : TransportBase, ITransport +{ + private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); + + private readonly ILogger _logger; + + private readonly TextReader _inputReader; + private readonly Stream _outputStream; + private readonly string _endpointName; + + private readonly SemaphoreSlim _sendLock = new(1, 1); + private readonly CancellationTokenSource _shutdownCts = new(); + + private readonly Task _readLoopCompleted; + private int _disposed = 0; + + /// + /// Initializes a new instance of the class with explicit input/output streams. + /// + /// The input to use as standard input. + /// The output to use as standard output. + /// Optional name of the server, used for diagnostic purposes, like logging. + /// Optional logger factory used for logging employed by the transport. + /// is . + /// is . + public StreamServerTransport(Stream inputStream, Stream outputStream, string? serverName = null, ILoggerFactory? loggerFactory = null) + : base(loggerFactory) + { + Throw.IfNull(inputStream); + Throw.IfNull(outputStream); + + _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; + + _inputReader = new StreamReader(inputStream, Encoding.UTF8); + _outputStream = outputStream; + + SetConnected(true); + _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); + + _endpointName = serverName is not null ? $"Server (stream) ({serverName})" : "Server (stream)"; + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + + await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false); + await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); + await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);; + + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + private async Task ReadMessagesAsync() + { + CancellationToken shutdownToken = _shutdownCts.Token; + try + { + _logger.TransportEnteringReadMessagesLoop(_endpointName); + + while (!shutdownToken.IsCancellationRequested) + { + _logger.TransportWaitingForMessage(_endpointName); + + var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(line)) + { + if (line is null) + { + _logger.TransportEndOfStream(_endpointName); + break; + } + + continue; + } + + _logger.TransportReceivedMessage(_endpointName, line); + _logger.TransportMessageBytesUtf8(_endpointName, line); + + try + { + if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + _logger.TransportReceivedMessageParsed(_endpointName, messageId); + + await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); + _logger.TransportMessageWritten(_endpointName, messageId); + } + else + { + _logger.TransportMessageParseUnexpectedType(_endpointName, line); + } + } + catch (JsonException ex) + { + _logger.TransportMessageParseFailed(_endpointName, line, ex); + // Continue reading even if we fail to parse a message + } + } + + _logger.TransportExitingReadMessagesLoop(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportReadMessagesCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + finally + { + SetConnected(false); + } + } + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } + + try + { + _logger.TransportCleaningUp(_endpointName); + + // Signal to the stdin reading loop to stop. + await _shutdownCts.CancelAsync().ConfigureAwait(false); + _shutdownCts.Dispose(); + + // Dispose of stdin/out. Cancellation may not be able to wake up operations + // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. + _inputReader?.Dispose(); + _outputStream?.Dispose(); + + // Make sure the work has quiesced. + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await _readLoopCompleted.ConfigureAwait(false); + _logger.TransportReadTaskCleanedUp(_endpointName); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + + GC.SuppressFinalize(this); + } +} diff --git a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs index b7c8d3f63..08c6e5770 100644 --- a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs @@ -1,4 +1,6 @@ -namespace ModelContextProtocol.Protocol.Types; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Types; /// /// Used by the client to invoke a tool provided by the server. @@ -16,5 +18,5 @@ public class CallToolRequestParams : RequestParams /// Optional arguments to pass to the tool. /// [System.Text.Json.Serialization.JsonPropertyName("arguments")] - public Dictionary? Arguments { get; init; } + public Dictionary? Arguments { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index f33357783..c0cf41977 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -55,7 +55,7 @@ public class SamplingCapability /// Gets or sets the handler for sampling requests. [JsonIgnore] - public Func>? SamplingHandler { get; set; } + public Func, CancellationToken, Task>? SamplingHandler { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs index db45dd48f..bb3bae905 100644 --- a/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs +++ b/src/ModelContextProtocol/Protocol/Types/ContextInclusion.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types; /// A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. /// See the schema for details /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum ContextInclusion { /// diff --git a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs index 419b6fceb..a5500d410 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListPromptsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of prompts and prompt templates the server has. /// See the schema for details /// -public class ListPromptsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListPromptsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs index f4060dbd0..8a54f6e8e 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourceTemplatesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resource templates the server has. /// See the schema for details /// -public class ListResourceTemplatesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} \ No newline at end of file +public class ListResourceTemplatesRequestParams : PaginatedRequestParams; \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs index ad7f19b31..30bea5b87 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListResourcesRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of resources the server has. /// See the schema for details /// -public class ListResourcesRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListResourcesRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs index dae1b75c1..a5eec7a15 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListRootsRequestParams.cs @@ -6,11 +6,4 @@ namespace ModelContextProtocol.Protocol.Types; /// A request from the server to get a list of root URIs from the client. /// See the schema for details /// -public class ListRootsRequestParams -{ - /// - /// Optional progress token for out-of-band progress notifications. - /// - [System.Text.Json.Serialization.JsonPropertyName("progressToken")] - public ProgressToken? ProgressToken { get; init; } -} +public class ListRootsRequestParams : RequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs index 4f18fbb73..64ac18599 100644 --- a/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/ListToolsRequestParams.cs @@ -4,12 +4,4 @@ /// Sent from the client to request a list of tools the server has. /// See the schema for details /// -public class ListToolsRequestParams -{ - /// - /// An opaque token representing the current pagination position. - /// If provided, the server should return results starting after this cursor. - /// - [System.Text.Json.Serialization.JsonPropertyName("cursor")] - public string? Cursor { get; init; } -} +public class ListToolsRequestParams : PaginatedRequestParams; diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs index df8c4c75a..8098dbbd3 100644 --- a/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs +++ b/src/ModelContextProtocol/Protocol/Types/LoggingLevel.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Protocol.Types; /// These map to syslog message severities, as specified in RFC-5424: /// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum LoggingLevel { /// Detailed debug information, typically only valuable to developers. diff --git a/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs b/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs index 8c153f254..072992cad 100644 --- a/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/LoggingMessageNotificationParams.cs @@ -24,7 +24,7 @@ public class LoggingMessageNotificationParams public string? Logger { get; init; } /// - /// The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. + /// The data to be logged, such as a string message or an object. /// [JsonPropertyName("data")] public JsonElement? Data { get; init; } diff --git a/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs new file mode 100644 index 000000000..abf47dd3c --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Types/PaginatedRequest.cs @@ -0,0 +1,15 @@ +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Used as a base class for paginated requests. +/// See the schema for details +/// +public class PaginatedRequestParams : RequestParams +{ + /// + /// An opaque token representing the current pagination position. + /// If provided, the server should return results starting after this cursor. + /// + [System.Text.Json.Serialization.JsonPropertyName("cursor")] + public string? Cursor { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs b/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs index 1daece16b..0d70f8b43 100644 --- a/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs +++ b/src/ModelContextProtocol/Protocol/Types/ResourceContents.cs @@ -61,6 +61,8 @@ public class Converter : JsonConverter } string? propertyName = reader.GetString(); + bool success = reader.Read(); + Debug.Assert(success, "STJ must have buffered the entire object for us."); switch (propertyName) { diff --git a/src/ModelContextProtocol/Protocol/Types/Role.cs b/src/ModelContextProtocol/Protocol/Types/Role.cs index 1cb35ea5b..c025f61ad 100644 --- a/src/ModelContextProtocol/Protocol/Types/Role.cs +++ b/src/ModelContextProtocol/Protocol/Types/Role.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Types; /// Represents the type of role in the conversation. /// See the schema for details /// -[JsonConverter(typeof(JsonStringEnumConverter))] +[JsonConverter(typeof(CustomizableJsonStringEnumConverter))] public enum Role { /// diff --git a/src/ModelContextProtocol/Protocol/Types/Tool.cs b/src/ModelContextProtocol/Protocol/Types/Tool.cs index ed0c71290..dc0b774c0 100644 --- a/src/ModelContextProtocol/Protocol/Types/Tool.cs +++ b/src/ModelContextProtocol/Protocol/Types/Tool.cs @@ -38,7 +38,7 @@ public JsonElement InputSchema { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) { - throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema)); + throw new ArgumentException("The specified document is not a valid MCP tool JSON schema.", nameof(InputSchema)); } _inputSchema = value; diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 47b8514de..0f33c8e41 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Diagnostics.CodeAnalysis; @@ -256,8 +255,8 @@ public override async Task InvokeAsync( cancellationToken.ThrowIfCancellationRequested(); // TODO: Once we shift to the real AIFunctionFactory, the request should be passed via AIFunctionArguments.Context. - Dictionary arguments = request.Params?.Arguments is IDictionary existingArgs ? - new(existingArgs) : + Dictionary arguments = request.Params?.Arguments is { } paramArgs ? + paramArgs.ToDictionary(entry => entry.Key, entry => entry.Value.AsObject()) : []; arguments[RequestContextKey] = request; diff --git a/src/ModelContextProtocol/Server/IMcpServer.cs b/src/ModelContextProtocol/Server/IMcpServer.cs index e8dffaf19..19b3967ad 100644 --- a/src/ModelContextProtocol/Server/IMcpServer.cs +++ b/src/ModelContextProtocol/Server/IMcpServer.cs @@ -1,12 +1,11 @@ -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; namespace ModelContextProtocol.Server; /// /// Represents a server that can communicate with a client using the MCP protocol. /// -public interface IMcpServer : IAsyncDisposable +public interface IMcpServer : IMcpEndpoint { /// /// Gets the capabilities supported by the client. @@ -26,42 +25,8 @@ public interface IMcpServer : IAsyncDisposable /// IServiceProvider? Services { get; } - /// - /// Adds a handler for client notifications of a specific method. - /// - /// The notification method to handle. - /// The async handler function to process notifications. - /// - /// - /// Each method may have multiple handlers. Adding a handler for a method that already has one - /// will not replace the existing handler. - /// - /// - /// provides constants for common notification methods. - /// - /// - void AddNotificationHandler(string method, Func handler); - /// /// Runs the server, listening for and handling client requests. /// Task RunAsync(CancellationToken cancellationToken = default); - - /// - /// Sends a generic JSON-RPC request to the client. - /// NB! This is a temporary method that is available to send not yet implemented feature messages. - /// Once all MCP features are implemented this will be made private, as it is purely a convenience for those who wish to implement features ahead of the library. - /// - /// The expected response type. - /// The JSON-RPC request to send. - /// A token to cancel the operation. - /// A task containing the client's response. - Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class; - - /// - /// Sends a message to the client. - /// - /// The message. - /// A token to cancel the operation. - Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index b95483196..499f2efa8 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,22 +1,21 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; -using System.Text.Json.Nodes; +using ModelContextProtocol.Utils.Json; namespace ModelContextProtocol.Server; /// -internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer +internal sealed class McpServer : McpEndpoint, IMcpServer { private readonly EventHandler? _toolsChangedDelegate; private readonly EventHandler? _promptsChangedDelegate; - private ITransport _sessionTransport; private string _endpointName; + private int _started; /// /// Creates a new instance of . @@ -33,7 +32,6 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Throw.IfNull(transport); Throw.IfNull(options); - _sessionTransport = transport; ServerOptions = options; Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo.Name} {options.ServerInfo.Version})"; @@ -69,13 +67,14 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? }); SetToolsHandler(options); - SetInitializeHandler(options); SetCompletionHandler(options); SetPingHandler(); SetPromptsHandler(options); SetResourcesHandler(options); SetSetLoggingLevelHandler(options); + + StartSession(transport); } public ServerCapabilities? ServerCapabilities { get; set; } @@ -98,11 +97,16 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? /// public async Task RunAsync(CancellationToken cancellationToken = default) { + if (Interlocked.Exchange(ref _started, 1) != 0) + { + throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); + } + try { - // Start processing messages - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); + using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this); + // The McpServer ctor always calls StartSession, so MessageProcessingTask is always set. + await MessageProcessingTask!.ConfigureAwait(false); } finally { @@ -127,13 +131,15 @@ public override async ValueTask DisposeUnsynchronizedAsync() private void SetPingHandler() { - SetRequestHandler(RequestMethods.Ping, - (request, _) => Task.FromResult(new PingResult())); + SetRequestHandler(RequestMethods.Ping, + (request, _) => Task.FromResult(new PingResult()), + McpJsonUtilities.JsonContext.Default.JsonNode, + McpJsonUtilities.JsonContext.Default.PingResult); } private void SetInitializeHandler(McpServerOptions options) { - SetRequestHandler(RequestMethods.Initialize, + SetRequestHandler(RequestMethods.Initialize, (request, _) => { ClientCapabilities = request?.Capabilities ?? new(); @@ -143,23 +149,27 @@ private void SetInitializeHandler(McpServerOptions options) _endpointName = $"{_endpointName}, Client ({ClientInfo?.Name} {ClientInfo?.Version})"; GetSessionOrThrow().EndpointName = _endpointName; - return Task.FromResult(new InitializeResult() + return Task.FromResult(new InitializeResult { ProtocolVersion = options.ProtocolVersion, Instructions = options.ServerInstructions, ServerInfo = options.ServerInfo, Capabilities = ServerCapabilities ?? new(), }); - }); + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult); } private void SetCompletionHandler(McpServerOptions options) { // This capability is not optional, so return an empty result if there is no handler. - SetRequestHandler(RequestMethods.CompletionComplete, + SetRequestHandler(RequestMethods.CompletionComplete, options.GetCompletionHandler is { } handler ? (request, ct) => handler(new(this, request), ct) : - (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } })); + (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }), + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult); } private void SetResourcesHandler(McpServerOptions options) @@ -180,11 +190,24 @@ private void SetResourcesHandler(McpServerOptions options) listResourcesHandler ??= (static (_, _) => Task.FromResult(new ListResourcesResult())); - SetRequestHandler(RequestMethods.ResourcesList, (request, ct) => listResourcesHandler(new(this, request), ct)); - SetRequestHandler(RequestMethods.ResourcesRead, (request, ct) => readResourceHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.ResourcesList, + (request, ct) => listResourcesHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult); + + SetRequestHandler( + RequestMethods.ResourcesRead, + (request, ct) => readResourceHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); - SetRequestHandler(RequestMethods.ResourcesTemplatesList, (request, ct) => listResourceTemplatesHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.ResourcesTemplatesList, + (request, ct) => listResourceTemplatesHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); if (resourcesCapability.Subscribe is not true) { @@ -198,8 +221,17 @@ private void SetResourcesHandler(McpServerOptions options) throw new McpServerException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } - SetRequestHandler(RequestMethods.ResourcesSubscribe, (request, ct) => subscribeHandler(new(this, request), ct)); - SetRequestHandler(RequestMethods.ResourcesUnsubscribe, (request, ct) => unsubscribeHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.ResourcesSubscribe, + (request, ct) => subscribeHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + + SetRequestHandler( + RequestMethods.ResourcesUnsubscribe, + (request, ct) => unsubscribeHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); } private void SetPromptsHandler(McpServerOptions options) @@ -214,41 +246,26 @@ private void SetPromptsHandler(McpServerOptions options) throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together."); } - // Handle tools provided via DI. + // Handle prompts provided via DI. if (prompts is { IsEmpty: false }) { + // Synthesize the handlers, making sure a PromptsCapability is specified. var originalListPromptsHandler = listPromptsHandler; - var originalGetPromptHandler = getPromptHandler; - - // Synthesize the handlers, making sure a ToolsCapability is specified. listPromptsHandler = async (request, cancellationToken) => { - ListPromptsResult result = new(); - foreach (McpServerPrompt prompt in prompts) - { - result.Prompts.Add(prompt.ProtocolPrompt); - } + ListPromptsResult result = originalListPromptsHandler is not null ? + await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); - if (originalListPromptsHandler is not null) + if (request.Params?.Cursor is null) { - string? nextCursor = null; - do - { - ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false); - result.Prompts.AddRange(extraResults.Prompts); - - nextCursor = extraResults.NextCursor; - if (nextCursor is not null) - { - request = request with { Params = new() { Cursor = nextCursor } }; - } - } - while (nextCursor is not null); + result.Prompts.AddRange(prompts.Select(t => t.ProtocolPrompt)); } return result; }; + var originalGetPromptHandler = getPromptHandler; getPromptHandler = (request, cancellationToken) => { if (request.Params is null || @@ -297,8 +314,17 @@ private void SetPromptsHandler(McpServerOptions options) } } - SetRequestHandler(RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct)); - SetRequestHandler(RequestMethods.PromptsGet, (request, ct) => getPromptHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.PromptsList, + (request, ct) => listPromptsHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult); + + SetRequestHandler( + RequestMethods.PromptsGet, + (request, ct) => getPromptHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult); } private void SetToolsHandler(McpServerOptions options) @@ -316,38 +342,23 @@ private void SetToolsHandler(McpServerOptions options) // Handle tools provided via DI. if (tools is { IsEmpty: false }) { - var originalListToolsHandler = listToolsHandler; - var originalCallToolHandler = callToolHandler; - // Synthesize the handlers, making sure a ToolsCapability is specified. + var originalListToolsHandler = listToolsHandler; listToolsHandler = async (request, cancellationToken) => { - ListToolsResult result = new(); - foreach (McpServerTool tool in tools) - { - result.Tools.Add(tool.ProtocolTool); - } + ListToolsResult result = originalListToolsHandler is not null ? + await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); - if (originalListToolsHandler is not null) + if (request.Params?.Cursor is null) { - string? nextCursor = null; - do - { - ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false); - result.Tools.AddRange(extraResults.Tools); - - nextCursor = extraResults.NextCursor; - if (nextCursor is not null) - { - request = request with { Params = new() { Cursor = nextCursor } }; - } - } - while (nextCursor is not null); + result.Tools.AddRange(tools.Select(t => t.ProtocolTool)); } return result; }; + var originalCallToolHandler = callToolHandler; callToolHandler = (request, cancellationToken) => { if (request.Params is null || @@ -396,8 +407,17 @@ private void SetToolsHandler(McpServerOptions options) } } - SetRequestHandler(RequestMethods.ToolsList, (request, ct) => listToolsHandler(new(this, request), ct)); - SetRequestHandler(RequestMethods.ToolsCall, (request, ct) => callToolHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.ToolsList, + (request, ct) => listToolsHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult); + + SetRequestHandler( + RequestMethods.ToolsCall, + (request, ct) => callToolHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResponse); } private void SetSetLoggingLevelHandler(McpServerOptions options) @@ -412,6 +432,10 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) throw new McpServerException("Logging capability was enabled, but SetLoggingLevelHandler was not specified."); } - SetRequestHandler(RequestMethods.LoggingSetLevel, (request, ct) => setLoggingLevelHandler(new(this, request), ct)); + SetRequestHandler( + RequestMethods.LoggingSetLevel, + (request, ct) => setLoggingLevelHandler(new(this, request), ct), + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index 3b541ec80..9b160d4c7 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -1,13 +1,15 @@ -using ModelContextProtocol.Protocol.Messages; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; -using Microsoft.Extensions.AI; +using ModelContextProtocol.Utils.Json; using System.Runtime.CompilerServices; using System.Text; namespace ModelContextProtocol.Server; -/// +/// Provides extension methods for interacting with an . public static class McpServerExtensions { /// @@ -25,9 +27,12 @@ public static Task RequestSamplingAsync( throw new ArgumentException("Client connected to the server does not support sampling.", nameof(server)); } - return server.SendRequestAsync( - new JsonRpcRequest { Method = RequestMethods.SamplingCreateMessage, Params = request }, - cancellationToken); + return server.SendRequestAsync( + RequestMethods.SamplingCreateMessage, + request, + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken: cancellationToken); } /// @@ -164,9 +169,12 @@ public static Task RequestRootsAsync( throw new ArgumentException("Client connected to the server does not support roots.", nameof(server)); } - return server.SendRequestAsync( - new JsonRpcRequest { Method = RequestMethods.RootsList, Params = request }, - cancellationToken); + return server.SendRequestAsync( + RequestMethods.RootsList, + request, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult, + cancellationToken: cancellationToken); } /// Provides an implementation that's implemented via client sampling. diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs index c79a09ed5..25bffe5ed 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; @@ -132,9 +133,9 @@ public static AIFunction Create(MethodInfo method, object? target, TemporaryAIFu /// /// /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are + /// . Arguments that are not already of the expected type are /// marshaled to the expected type via JSON and using 's - /// . If the argument is a , + /// . If the argument is a , /// , or , it is deserialized directly. If the argument is anything else unknown, /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. /// @@ -476,7 +477,7 @@ private static JsonElement CreateFunctionJsonSchema( description: parameter.GetCustomAttribute(inherit: true)?.Description, hasDefaultValue: parameter.HasDefaultValue, defaultValue: parameter.HasDefaultValue ? parameter.DefaultValue : null, - serializerOptions), AIJsonUtilities.DefaultOptions.GetTypeInfo()); + serializerOptions), McpJsonUtilities.JsonContext.Default.JsonElement); parameterSchemas.Add(parameter.Name, parameterSchema); if (!parameter.IsOptional) diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs index e1f712d1f..91403b3b9 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactoryOptions.cs @@ -72,7 +72,7 @@ public TemporaryAIFunctionFactoryOptions() /// Gets or sets a delegate used to determine the returned by . /// /// - /// By default, the return value of invoking the method wrapped into an by + /// By default, the return value of invoking the method wrapped into an by /// is then JSON serialized, with the resulting returned from the method. /// This default behavior is ideal for the common case where the result will be passed back to an AI service. However, if the caller /// requires more control over the result's marshaling, the property may be set to a delegate that is @@ -82,7 +82,7 @@ public TemporaryAIFunctionFactoryOptions() /// /// When set, the delegate is invoked even for -returning methods, in which case it is invoked with /// a argument. By default, is returned from the - /// method for instances produced by to wrap + /// method for instances produced by to wrap /// -returning methods). /// /// diff --git a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs similarity index 64% rename from src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs rename to src/ModelContextProtocol/Shared/McpEndpoint.cs index ac0337f9e..8b50a8052 100644 --- a/src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -3,8 +3,10 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Shared; @@ -15,14 +17,13 @@ namespace ModelContextProtocol.Shared; /// This is especially true as a client represents a connection to one and only one server, and vice versa. /// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction. /// -internal abstract class McpJsonRpcEndpoint : IAsyncDisposable +internal abstract class McpEndpoint : IAsyncDisposable { private readonly RequestHandlers _requestHandlers = []; private readonly NotificationHandlers _notificationHandlers = []; private McpSession? _session; private CancellationTokenSource? _sessionCts; - private int _started; private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; @@ -30,22 +31,27 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable protected readonly ILogger _logger; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The logger factory. - protected McpJsonRpcEndpoint(ILoggerFactory? loggerFactory = null) + protected McpEndpoint(ILoggerFactory? loggerFactory = null) { _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; } - protected void SetRequestHandler(string method, Func> handler) - => _requestHandlers.Set(method, handler); + protected void SetRequestHandler( + string method, + Func> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) + + => _requestHandlers.Set(method, handler, requestTypeInfo, responseTypeInfo); public void AddNotificationHandler(string method, Func handler) => _notificationHandlers.Add(method, handler); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class - => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); @@ -58,21 +64,18 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella /// /// Task that processes incoming messages from the transport. /// - protected Task? MessageProcessingTask { get; set; } + protected Task? MessageProcessingTask { get; private set; } [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken = default) + protected void StartSession(ITransport sessionTransport) { - if (Interlocked.Exchange(ref _started, 1) != 0) - { - throw new InvalidOperationException("The MCP session has already stared."); - } - - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - _session = new McpSession(sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); + _sessionCts = new CancellationTokenSource(); + _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, _requestHandlers, _notificationHandlers, _logger); MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token); } + protected void CancelSession() => _sessionCts?.Cancel(); + public async ValueTask DisposeAsync() { using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); @@ -94,25 +97,30 @@ public virtual async ValueTask DisposeUnsynchronizedAsync() { _logger.CleaningUpEndpoint(EndpointName); - if (_sessionCts is not null) - { - await _sessionCts.CancelAsync().ConfigureAwait(false); - } - - if (MessageProcessingTask is not null) + try { - try + if (_sessionCts is not null) { - await MessageProcessingTask.ConfigureAwait(false); + await _sessionCts.CancelAsync().ConfigureAwait(false); } - catch (OperationCanceledException) + + if (MessageProcessingTask is not null) { - // Ignore cancellation + try + { + await MessageProcessingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Ignore cancellation + } } } - - _session?.Dispose(); - _sessionCts?.Dispose(); + finally + { + _session?.Dispose(); + _sessionCts?.Dispose(); + } _logger.EndpointCleanedUp(EndpointName); } diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 97cbcb592..d5e4f930d 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -4,10 +4,15 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.Metrics; using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading.Channels; namespace ModelContextProtocol.Shared; @@ -16,9 +21,21 @@ namespace ModelContextProtocol.Shared; /// internal sealed class McpSession : IDisposable { + private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); + private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); + private static readonly Histogram s_clientRequestDuration = Diagnostics.CreateDurationHistogram( + "rpc.client.duration", "Measures the duration of outbound RPC.", longBuckets: false); + private static readonly Histogram s_serverRequestDuration = Diagnostics.CreateDurationHistogram( + "rpc.server.duration", "Measures the duration of inbound RPC.", longBuckets: false); + + private readonly bool _isServer; + private readonly string _transportKind; private readonly ITransport _transport; private readonly RequestHandlers _requestHandlers; private readonly NotificationHandlers _notificationHandlers; + private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); /// Collection of requests sent on this session and waiting for responses. private readonly ConcurrentDictionary> _pendingRequests = []; @@ -27,21 +44,22 @@ internal sealed class McpSession : IDisposable /// that can be used to request cancellation of the in-flight handler. /// private readonly ConcurrentDictionary _handlingRequests = new(); - private readonly JsonSerializerOptions _jsonOptions; private readonly ILogger _logger; - + private readonly string _id = Guid.NewGuid().ToString("N"); private long _nextRequestId; /// /// Initializes a new instance of the class. /// + /// true if this is a server; false if it's a client. /// An MCP transport implementation. /// The name of the endpoint for logging and debug purposes. /// A collection of request handlers. /// A collection of notification handlers. /// The logger. public McpSession( + bool isServer, ITransport transport, string endpointName, RequestHandlers requestHandlers, @@ -50,11 +68,19 @@ public McpSession( { Throw.IfNull(transport); + _transportKind = transport switch + { + StdioClientSessionTransport or StdioServerTransport => "stdio", + StreamClientSessionTransport or StreamServerTransport => "stream", + SseClientSessionTransport or SseResponseStreamTransport => "sse", + _ => "unknownTransport" + }; + + _isServer = isServer; _transport = transport; EndpointName = endpointName; _requestHandlers = requestHandlers; _notificationHandlers = notificationHandlers; - _jsonOptions = McpJsonUtilities.DefaultOptions; _logger = logger ?? NullLogger.Instance; } @@ -121,14 +147,14 @@ await _transport.SendMessageAsync(new JsonRpcError JsonRpc = "2.0", Error = new JsonRpcErrorDetail { - Code = ErrorCodes.InternalError, + Code = (ex as McpServerException)?.ErrorCode ?? ErrorCodes.InternalError, Message = ex.Message } }, cancellationToken).ConfigureAwait(false); } else if (ex is not OperationCanceledException) { - var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); + var payload = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage); _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex); } } @@ -148,27 +174,67 @@ await _transport.SendMessageAsync(new JsonRpcError // Normal shutdown _logger.EndpointMessageProcessingCancelled(EndpointName); } + finally + { + // Fail any pending requests, as they'll never be satisfied. + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetException(new InvalidOperationException("The server shut down unexpectedly.")); + } + } } private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken) { - switch (message) + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + try { - case JsonRpcRequest request: - await HandleRequest(request, cancellationToken).ConfigureAwait(false); - break; + if (addTags) + { + AddStandardTags(ref tags, method); + } - case IJsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; + switch (message) + { + case JsonRpcRequest request: + if (addTags) + { + AddRpcRequestTags(ref tags, activity, request); + } + + await HandleRequest(request, cancellationToken).ConfigureAwait(false); + break; - case JsonRpcNotification notification: - await HandleNotification(notification).ConfigureAwait(false); - break; + case JsonRpcNotification notification: + await HandleNotification(notification).ConfigureAwait(false); + break; - default: - _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; + case IJsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + default: + _logger.EndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + catch (Exception e) when (addTags) + { + AddExceptionTags(ref tags, e); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } @@ -212,7 +278,7 @@ private async Task HandleNotification(JsonRpcNotification notification) private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) { - if (messageWithId.Id.IsDefault) + if (messageWithId.Id.Id is null) { _logger.RequestHasInvalidId(EndpointName); } @@ -229,34 +295,32 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) { - if (_requestHandlers.TryGetValue(request.Method, out var handler)) - { - _logger.RequestHandlerCalled(EndpointName, request.Method); - var result = await handler(request, cancellationToken).ConfigureAwait(false); - _logger.RequestHandlerCompleted(EndpointName, request.Method); - await _transport.SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - JsonRpc = "2.0", - Result = result - }, cancellationToken).ConfigureAwait(false); - } - else + if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { _logger.NoHandlerFoundForRequest(EndpointName, request.Method); + throw new McpServerException("The method does not exist or is not available.", ErrorCodes.MethodNotFound); } + + _logger.RequestHandlerCalled(EndpointName, request.Method); + JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); + _logger.RequestHandlerCompleted(EndpointName, request.Method); + await _transport.SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + JsonRpc = "2.0", + Result = result + }, cancellationToken).ConfigureAwait(false); } /// - /// Sends a generic JSON-RPC request to the server. + /// Sends a JSON-RPC request to the server. /// It is strongly recommended use the capability-specific methods instead of this one. /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. /// - /// The expected response type. /// The JSON-RPC request to send. /// A token to cancel the operation. /// A task containing the server's response. - public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where TResult : class + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { if (!_transport.IsConnected) { @@ -264,21 +328,37 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can throw new McpClientException("Transport is not connected"); } + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = request.Method; + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + // Set request ID - if (request.Id.IsDefault) + if (request.Id.Id is null) { request.Id = new RequestId($"{_id}-{Interlocked.Increment(ref _nextRequestId)}"); } + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _pendingRequests[request.Id] = tcs; - try { + if (addTags) + { + AddStandardTags(ref tags, method); + AddRpcRequestTags(ref tags, activity, request); + } + // Expensive logging, use the logging framework to check if the logger is enabled if (_logger.IsEnabled(LogLevel.Debug)) { - _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, _jsonOptions.GetTypeInfo())); + _logger.SendingRequestPayload(EndpointName, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcRequest)); } // Less expensive information logging @@ -297,31 +377,24 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can if (response is JsonRpcResponse success) { - // Convert the Result object to JSON and back to get our strongly-typed result - var resultJson = JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo()); - var resultObject = JsonSerializer.Deserialize(resultJson, _jsonOptions.GetTypeInfo()); - - // Not expensive logging because we're already converting to JSON in order to get the result object - _logger.RequestResponseReceivedPayload(EndpointName, resultJson); + _logger.RequestResponseReceivedPayload(EndpointName, success.Result?.ToJsonString() ?? "null"); _logger.RequestResponseReceived(EndpointName, request.Method); - - if (resultObject != null) - { - return resultObject; - } - - // Result object was null, this is unexpected - _logger.RequestResponseTypeConversionError(EndpointName, request.Method, typeof(TResult)); - throw new McpClientException($"Unexpected response type {JsonSerializer.Serialize(success.Result, _jsonOptions.GetTypeInfo())}, expected {typeof(TResult)}"); + return success; } // Unexpected response type _logger.RequestInvalidResponseType(EndpointName, request.Method); throw new McpClientException("Invalid response type"); } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, ex); + throw; + } finally { _pendingRequests.TryRemove(request.Id, out _); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } @@ -335,58 +408,185 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca throw new McpClientException("Transport is not connected"); } - if (_logger.IsEnabled(LogLevel.Debug)) + Histogram durationMetric = _isServer ? s_serverRequestDuration : s_clientRequestDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ActivitySource.HasListeners() ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method)) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + try { - _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); - } + if (addTags) + { + AddStandardTags(ref tags, method); + } - await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage)); + } + + await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, ex); + throw; + } + finally { - tcs.TrySetCanceled(default); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); } } - private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams) + private static CancelledNotification? GetCancelledNotificationParams(JsonNode? notificationParams) { try { - switch (notificationParams) + return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotification); + } + catch + { + return null; + } + } + + private string CreateActivityName(string method) => + $"mcp.{(_isServer ? "server" : "client")}.{_transportKind}/{method}"; + + private static string GetMethodName(IJsonRpcMessage message) => + message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => "unknownMethod", + }; + + private void AddStandardTags(ref TagList tags, string method) + { + tags.Add("session.id", _id); + tags.Add("rpc.system", "jsonrpc"); + tags.Add("rpc.jsonrpc.version", "2.0"); + tags.Add("rpc.method", method); + tags.Add("network.transport", _transportKind); + + // RPC spans convention also includes: + // server.address, server.port, client.address, client.port, network.peer.address, network.peer.port, network.type + } + + private static void AddRpcRequestTags(ref TagList tags, Activity? activity, JsonRpcRequest request) + { + tags.Add("rpc.jsonrpc.request_id", request.Id.ToString()); + + if (request.Params is JsonObject paramsObj) + { + switch (request.Method) { - case null: - return null; + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + if (paramsObj.TryGetPropertyValue("name", out var prop) && prop?.GetValueKind() is JsonValueKind.String) + { + string name = prop.GetValue(); + tags.Add("mcp.request.params.name", name); + if (activity is not null) + { + activity.DisplayName = $"{request.Method}({name})"; + } + } + break; - case CancelledNotification cn: - return cn; + case RequestMethods.ResourcesRead: + if (paramsObj.TryGetPropertyValue("uri", out prop) && prop?.GetValueKind() is JsonValueKind.String) + { + string uri = prop.GetValue(); + tags.Add("mcp.request.params.uri", uri); + if (activity is not null) + { + activity.DisplayName = $"{request.Method}({uri})"; + } + } + break; + } + } + } - case JsonElement je: - return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + private static void AddExceptionTags(ref TagList tags, Exception e) + { + tags.Add("error.type", e.GetType().FullName); + tags.Add("rpc.jsonrpc.error_code", + (e as McpClientException)?.ErrorCode is int clientError ? clientError : + (e as McpServerException)?.ErrorCode is int serverError ? serverError : + e is JsonException ? ErrorCodes.ParseError : + ErrorCodes.InternalError); + } - default: - return JsonSerializer.Deserialize( - JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo()), - McpJsonUtilities.DefaultOptions.GetTypeInfo()); + private static void FinalizeDiagnostics( + Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + { + try + { + if (startingTimestamp is not null) + { + durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); + } + + if (activity is { IsAllDataRequested: true }) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } } } - catch + finally { - return null; + activity?.Dispose(); } } public void Dispose() { + Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; + if (durationMetric.Enabled) + { + TagList tags = default; + tags.Add("session.id", _id); + tags.Add("network.transport", _transportKind); + durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); + } + // Complete all pending requests with cancellation foreach (var entry in _pendingRequests) { entry.Value.TrySetCanceled(); } + _pendingRequests.Clear(); } + +#if !NET + private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; +#endif + + private static TimeSpan GetElapsed(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); +#endif } diff --git a/src/ModelContextProtocol/Shared/RequestHandlers.cs b/src/ModelContextProtocol/Shared/RequestHandlers.cs index be1f80c99..83c911cd7 100644 --- a/src/ModelContextProtocol/Shared/RequestHandlers.cs +++ b/src/ModelContextProtocol/Shared/RequestHandlers.cs @@ -1,11 +1,12 @@ using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Shared; -internal sealed class RequestHandlers : Dictionary>> +internal sealed class RequestHandlers : Dictionary>> { /// /// Registers a handler for incoming requests of a specific method. @@ -14,18 +15,24 @@ internal sealed class RequestHandlers : DictionaryType of response payload (not full RPC response /// Method identifier to register for /// Handler to be called when a request with specified method identifier is received - public void Set(string method, Func> handler) + /// The JSON contract governing request serialization. + /// The JSON contract governing response serialization. + public void Set( + string method, + Func> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) { Throw.IfNull(method); Throw.IfNull(handler); + Throw.IfNull(requestTypeInfo); + Throw.IfNull(responseTypeInfo); this[method] = async (request, cancellationToken) => { - // Convert the params JsonElement to our type using the same options - var jsonString = JsonSerializer.Serialize(request.Params, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - var typedRequest = JsonSerializer.Deserialize(jsonString, McpJsonUtilities.DefaultOptions.GetTypeInfo()); - - return await handler(typedRequest, cancellationToken).ConfigureAwait(false); + TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); + object? result = await handler(typedRequest, cancellationToken).ConfigureAwait(false); + return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } } diff --git a/src/ModelContextProtocol/TokenProgress.cs b/src/ModelContextProtocol/TokenProgress.cs index 7cc97236a..62834e75a 100644 --- a/src/ModelContextProtocol/TokenProgress.cs +++ b/src/ModelContextProtocol/TokenProgress.cs @@ -1,31 +1,16 @@ using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Server; -using ModelContextProtocol.Shared; namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications to the supplied endpoint. +/// progress notifications on the supplied endpoint. /// -internal sealed class TokenProgress(IMcpServer server, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = server.SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ProgressNotification, - Params = new ProgressNotification() - { - ProgressToken = progressToken, - Progress = new() - { - Progress = value.Progress, - Total = value.Total, - Message = value.Message, - }, - }, - }, CancellationToken.None); + _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs b/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs new file mode 100644 index 000000000..e9c26f18c --- /dev/null +++ b/src/ModelContextProtocol/Utils/Json/CustomizableJsonStringEnumConverter.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +// NOTE: +// This is a temporary workaround for lack of System.Text.Json's JsonStringEnumConverter +// 9.x support for JsonStringEnumMemberNameAttribute. Once all builds use the System.Text.Json 9.x +// version, this whole file can be removed. + +namespace System.Text.Json.Serialization; + +internal sealed class CustomizableJsonStringEnumConverter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> : + JsonStringEnumConverter where TEnum : struct, Enum +{ +#if !NET9_0_OR_GREATER + public CustomizableJsonStringEnumConverter() : + base(namingPolicy: ResolveNamingPolicy()) + { + } + + private static JsonNamingPolicy? ResolveNamingPolicy() + { + var map = typeof(TEnum).GetFields(BindingFlags.Public | BindingFlags.Static) + .Select(f => (f.Name, AttributeName: f.GetCustomAttribute()?.Name)) + .Where(pair => pair.AttributeName != null) + .ToDictionary(pair => pair.Name, pair => pair.AttributeName); + + return map.Count > 0 ? new EnumMemberNamingPolicy(map!) : null; + } + + private sealed class EnumMemberNamingPolicy(Dictionary map) : JsonNamingPolicy + { + public override string ConvertName(string name) => + map.TryGetValue(name, out string? newName) ? + newName : + name; + } +#endif +} + +#if !NET9_0_OR_GREATER +/// +/// Determines the string value that should be used when serializing an enum member. +/// +[AttributeUsage(AttributeTargets.Field, AllowMultiple = false)] +internal sealed class JsonStringEnumMemberNameAttribute : Attribute +{ + /// + /// Creates new attribute instance with a specified enum member name. + /// + /// The name to apply to the current enum member. + public JsonStringEnumMemberNameAttribute(string name) + { + Name = name; + } + + /// + /// Gets the name of the enum member. + /// + public string Name { get; } +} +#endif \ No newline at end of file diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 337cdbda9..e6245c7fa 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -3,7 +3,6 @@ using ModelContextProtocol.Protocol.Types; using System.Diagnostics.CodeAnalysis; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Metadata; @@ -35,36 +34,12 @@ public static partial class McpJsonUtilities /// Creates default options to use for MCP-related serialization. /// /// The configured options. - [UnconditionalSuppressMessage("AotAnalysis", "IL3050", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] - [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026", Justification = "DefaultJsonTypeInfoResolver is only used when reflection-based serialization is enabled")] private static JsonSerializerOptions CreateDefaultOptions() { - // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, - // and we want to be flexible in terms of what can be put into the various collections in the object model. - // Otherwise, use the source-generated options to enable trimming and Native AOT. - JsonSerializerOptions options; + // Copy the configuration from the source generated context. + JsonSerializerOptions options = new(JsonContext.Default.Options); - if (JsonSerializer.IsReflectionEnabledByDefault) - { - // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext below. - options = new(JsonSerializerDefaults.Web) - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver(), - Converters = { new JsonStringEnumConverter() }, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - NumberHandling = JsonNumberHandling.AllowReadingFromString, - }; - } - else - { - // Keep in sync with any additional settings above beyond what's in JsonContext below. - options = new(JsonContext.Default.Options) - { - }; - } - - // Include all types from AIJsonUtilities, so that anything default usable as part of an AIFunction - // is also usable as part of an McpServerTool. + // Chain with all supported types from MEAI options.TypeInfoResolverChain.Add(AIJsonUtilities.DefaultOptions.TypeInfoResolver!); options.MakeReadOnly(); @@ -75,6 +50,7 @@ internal static JsonTypeInfo GetTypeInfo(this JsonSerializerOptions option (JsonTypeInfo)options.GetTypeInfo(typeof(T)); internal static JsonElement DefaultMcpToolSchema { get; } = ParseJsonElement("""{"type":"object"}"""u8); + internal static object? AsObject(this JsonElement element) => element.ValueKind is JsonValueKind.Null ? null : element; internal static bool IsValidMcpToolSchema(JsonElement element) { @@ -102,14 +78,8 @@ internal static bool IsValidMcpToolSchema(JsonElement element) // Keep in sync with CreateDefaultOptions above. [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, - UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] - - // JSON - [JsonSerializable(typeof(JsonDocument))] - [JsonSerializable(typeof(JsonElement))] - [JsonSerializable(typeof(JsonNode))] // JSON-RPC [JsonSerializable(typeof(IJsonRpcMessage))] @@ -151,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeFromResourceRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] + [JsonSerializable(typeof(IReadOnlyDictionary))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj index 67dc6a197..fb6320a07 100644 --- a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj +++ b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + net9.0;net8.0 enable enable TestServer @@ -10,8 +10,11 @@ + + + diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 312013fa5..6e5655a3d 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -24,6 +24,10 @@ private static ILoggerFactory CreateLoggerFactory() return LoggerFactory.Create(builder => { + builder.AddConsole(options => + { + options.LogToStandardErrorThreshold = LogLevel.Trace; + }); builder.AddSerilog(); }); } @@ -75,11 +79,11 @@ private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken await server.SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.LoggingMessageNotification, - Params = new LoggingMessageNotificationParams + Params = JsonSerializer.SerializeToNode(new LoggingMessageNotificationParams { Level = logLevel, Data = JsonSerializer.Deserialize("\"Random log message\"") - } + }) }, cancellationToken); } @@ -87,10 +91,10 @@ await server.SendMessageAsync(new JsonRpcNotification() foreach (var resource in _subscribedResources) { ResourceUpdatedNotificationParams notificationParams = new() { Uri = resource.Key }; - await server.SendMessageAsync(new JsonRpcNotification() + await server.SendMessageAsync(new JsonRpcNotification { Method = NotificationMethods.ResourceUpdatedNotification, - Params = notificationParams + Params = JsonSerializer.SerializeToNode(notificationParams), }, cancellationToken); } } @@ -164,7 +168,7 @@ private static ToolsCapability ConfigureTools() } return new CallToolResponse() { - Content = [new Content() { Text = "Echo: " + message?.ToString(), Type = "text" }] + Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }] }; } else if (request.Params?.Name == "sampleLLM") @@ -175,7 +179,7 @@ private static ToolsCapability ConfigureTools() { throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); } - var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), cancellationToken); return new CallToolResponse() diff --git a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj index 6633ad4ad..3015fd554 100644 --- a/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj +++ b/tests/ModelContextProtocol.TestSseServer/ModelContextProtocol.TestSseServer.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + net9.0;net8.0 enable enable TestSseServer @@ -10,6 +10,7 @@ + diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 4bbc5bfc8..d5a24c997 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.AspNetCore; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Serilog; using System.Text; @@ -176,7 +175,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); } - var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), cancellationToken); return new CallToolResponse() @@ -378,10 +377,12 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide { Console.WriteLine("Starting server..."); + int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001; + var builder = WebApplication.CreateSlimBuilder(args); builder.WebHost.ConfigureKestrel(options => { - options.ListenLocalhost(3001); + options.ListenLocalhost(port); }); ConfigureSerilog(builder.Logging); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 710679e9f..1a451a2d5 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,10 +1,15 @@ +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; +using Moq; using System.IO.Pipelines; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Tests.Client; @@ -23,7 +28,7 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) sc.AddSingleton(LoggerFactory); sc.AddMcpServer().WithStdioServerTransport(); // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); for (int f = 0; f < 10; f++) { string name = $"Method{f}"; @@ -38,6 +43,180 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) _serverTask = server.RunAsync(cancellationToken: _cts.Token); } + [Theory] + [InlineData(null, null)] + [InlineData(0.7f, 50)] + [InlineData(1.0f, 100)] + public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content { Type = "text", Text = "Hello" } + } + ], + Temperature = temperature, + MaxTokens = maxTokens, + Meta = new RequestParamsMetadata + { + ProgressToken = new ProgressToken(), + } + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("Hello, World!", result.Content.Text); + Assert.Equal("test-model", result.Model); + Assert.Equal("assistant", result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleImageMessages() + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content + { + Type = "image", + MimeType = "image/png", + Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) + } + } + ], + MaxTokens = 100 + }; + + const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedData, result.Content.Data); + Assert.Equal("test-model", result.Model); + Assert.Equal("assistant", result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + { + // Arrange + const string data = "SGVsbG8sIFdvcmxkIQ=="; + string content = $"data:application/octet-stream;base64,{data}"; + var mockChatClient = new Mock(); + var resource = new BlobResourceContents + { + Blob = data, + MimeType = "application/octet-stream", + Uri = "data:application/octet-stream" + }; + + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content + { + Type = "resource", + Resource = resource + }, + } + ], + MaxTokens = 100 + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + AuthorName = "bot", + Role = ChatRole.Assistant, + Contents = + [ + resource.ToAIContent() + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("test-model", result.Model); + Assert.Equal(ChatRole.Assistant.ToString(), result.Role); + Assert.Equal("endTurn", result.StopReason); + } + public async ValueTask DisposeAsync() { await _cts.CancelAsync(); @@ -53,19 +232,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + serverOutput: _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -75,7 +252,7 @@ public async Task ListToolsAsync_AllToolsReturned() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(12, tools.Count); var echo = tools.Single(t => t.Name == "Method4"); var result = await echo.InvokeAsync(new Dictionary() { ["i"] = 42 }, TestContext.Current.CancellationToken); @@ -101,7 +278,7 @@ public async Task EnumerateToolsAsync_AllToolsReturned() { IMcpClient client = await CreateMcpClientForServer(); - await foreach (var tool in client.EnumerateToolsAsync(TestContext.Current.CancellationToken)) + await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) { if (tool.Name == "Method4") { @@ -113,4 +290,90 @@ public async Task EnumerateToolsAsync_AllToolsReturned() Assert.Fail("Couldn't find target method"); } + + [Fact] + public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + IMcpClient client = await CreateMcpClientForServer(); + bool hasTools = false; + + await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + hasTools = true; + } + + foreach (var tool in await client.ListToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + } + + Assert.True(hasTools); + } + + [Fact] + public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + IMcpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); + await Assert.ThrowsAsync(() => tool.InvokeAsync(new Dictionary { ["i"] = 42 }, TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendRequestAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + IMcpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(() => client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendNotificationAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + IMcpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task GetPromptsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + IMcpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(() => client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task WithName_ChangesToolName() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + IMcpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); + var originalName = tool.Name; + var renamedTool = tool.WithName("RenamedTool"); + + Assert.NotNull(renamedTool); + Assert.Equal("RenamedTool", renamedTool.Name); + Assert.Equal(originalName, tool?.Name); + } + + [Fact] + public async Task WithDescription_ChangesToolDescription() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + IMcpClient client = await CreateMcpClientForServer(); + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); + var originalDescription = tool?.Description; + var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); + Assert.NotNull(redescribedTool); + Assert.Equal("ToolWithNewDescription", redescribedTool.Description); + Assert.Equal(originalDescription, tool?.Description); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index f6b8f2b1e..ae58023fa 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -1,8 +1,11 @@ -using System.Threading.Channels; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; +using Moq; +using System.Text.Json; +using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; @@ -186,7 +189,68 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); } - private sealed class NopTransport : ITransport, IClientTransport + [Theory] + [InlineData(typeof(NopTransport))] + [InlineData(typeof(FailureTransport))] + public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) + { + // Arrange + var serverConfig = new McpServerConfig + { + Id = "TestServer", + Name = "TestServer", + TransportType = "stdio", + Location = "test-location" + }; + + var clientOptions = new McpClientOptions + { + ClientInfo = new Implementation + { + Name = "TestClient", + Version = "1.0.0.0" + }, + Capabilities = new ClientCapabilities + { + Sampling = new SamplingCapability + { + SamplingHandler = (c, p, t) => Task.FromResult( + new CreateMessageResult { + Content = new Content { Text = "result" }, + Model = "test-model", + Role = "test-role", + StopReason = "endTurn" + }), + }, + Roots = new RootsCapability + { + ListChanged = true, + RootsHandler = (t, r) => Task.FromResult(new ListRootsResult { Roots = [] }), + } + } + }; + + var clientTransport = (IClientTransport?)Activator.CreateInstance(transportType); + IMcpClient? client = null; + + var actionTask = McpClientFactory.CreateAsync(serverConfig, clientOptions, (config, logger) => clientTransport ?? new NopTransport(), new Mock().Object, CancellationToken.None); + + // Act + if (clientTransport is FailureTransport) + { + var exception = await Assert.ThrowsAsync(async() => await actionTask); + Assert.Equal(FailureTransport.ExpectedMessage, exception.Message); + } + else + { + client = await actionTask; + + // Assert + Assert.NotNull(client); + } + } + + private class NopTransport : ITransport, IClientTransport { private readonly Channel _channel = Channel.CreateUnbounded(); @@ -198,7 +262,7 @@ private sealed class NopTransport : ITransport, IClientTransport public ValueTask DisposeAsync() => default; - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { switch (message) { @@ -206,16 +270,16 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella _channel.Writer.TryWrite(new JsonRpcResponse { Id = ((JsonRpcRequest)message).Id, - Result = new InitializeResult() + Result = JsonSerializer.SerializeToNode(new InitializeResult { Capabilities = new ServerCapabilities(), ProtocolVersion = "2024-11-05", - ServerInfo = new Implementation() + ServerInfo = new Implementation { Name = "NopTransport", Version = "1.0.0" }, - } + }), }); break; } @@ -223,4 +287,14 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella return Task.CompletedTask; } } + + private sealed class FailureTransport : NopTransport + { + public const string ExpectedMessage = "Something failed"; + + public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException(ExpectedMessage); + } + } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 9b598d332..a6d0a9b61 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -68,7 +68,7 @@ public async Task ListTools_Stdio(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); @@ -88,7 +88,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) { ["message"] = "Hello MCP!" }, - TestContext.Current.CancellationToken + cancellationToken: TestContext.Current.CancellationToken ); // assert @@ -106,7 +106,7 @@ public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var aiFunctions = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var aiFunctions = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echo = aiFunctions.Single(t => t.Name == "echo"); var result = await echo.InvokeAsync([new KeyValuePair("message", "Hello MCP!")], TestContext.Current.CancellationToken); @@ -140,7 +140,7 @@ public async Task GetPrompt_Stdio_SimplePrompt(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); - var result = await client.GetPromptAsync("simple_prompt", null, TestContext.Current.CancellationToken); + var result = await client.GetPromptAsync("simple_prompt", null, cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -160,7 +160,7 @@ public async Task GetPrompt_Stdio_ComplexPrompt(string clientId) { "temperature", "0.7" }, { "style", "formal" } }; - var result = await client.GetPromptAsync("complex_prompt", arguments, TestContext.Current.CancellationToken); + var result = await client.GetPromptAsync("complex_prompt", arguments, cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -176,7 +176,7 @@ public async Task GetPrompt_NonExistent_ThrowsException(string clientId) // act await using var client = await _fixture.CreateClientAsync(clientId); await Assert.ThrowsAsync(() => - client.GetPromptAsync("non_existent_prompt", null, TestContext.Current.CancellationToken)); + client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); } [Theory] @@ -259,7 +259,7 @@ public async Task SubscribeResource_Stdio() await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { - var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); + var notificationParams = JsonSerializer.Deserialize(notification.Params); tcs.TrySetResult(true); return Task.CompletedTask; }); @@ -280,7 +280,7 @@ public async Task UnsubscribeResource_Stdio() await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.ResourceUpdatedNotification, (notification) => { - var notificationParams = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty); + var notificationParams = JsonSerializer.Deserialize(notification.Params); receivedNotification.TrySetResult(true); return Task.CompletedTask; }); @@ -355,7 +355,7 @@ public async Task Sampling_Stdio(string clientId) { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult @@ -381,7 +381,7 @@ public async Task Sampling_Stdio(string clientId) ["prompt"] = "Test prompt", ["maxTokens"] = 100 }, - TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -429,7 +429,7 @@ public async Task Notifications_Stdio(string clientId) // Verify we can send notifications without errors await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification, cancellationToken: TestContext.Current.CancellationToken); - await client.SendNotificationAsync("test/notification", new { test = true }, TestContext.Current.CancellationToken); + await client.SendNotificationAsync("test/notification", new { test = true }, cancellationToken: TestContext.Current.CancellationToken); // assert // no response to check, if no exception is thrown, it's a success @@ -467,7 +467,7 @@ public async Task CallTool_Stdio_MemoryServer() var result = await client.CallToolAsync( "read_graph", new Dictionary(), - TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -485,7 +485,7 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() _fixture.EverythingServerConfig, _fixture.DefaultOptions, cancellationToken: TestContext.Current.CancellationToken); - var mappedTools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // Create the chat client. using IChatClient chatClient = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini") @@ -511,6 +511,9 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() [Fact(Skip = "Requires OpenAI API Key", SkipWhen = nameof(NoOpenAIKeySet))] public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { + var samplingHandler = new OpenAIClient(s_openAIKey) + .AsChatClient("gpt-4o-mini") + .CreateSamplingHandler(); await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() { ClientInfo = new() { Name = nameof(SamplingViaChatClient_RequestResponseProperlyPropagated), Version = "1.0.0" }, @@ -518,7 +521,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() { Sampling = new() { - SamplingHandler = new OpenAIClient(s_openAIKey).AsChatClient("gpt-4o-mini").CreateSamplingHandler(), + SamplingHandler = samplingHandler, }, }, }, cancellationToken: TestContext.Current.CancellationToken); @@ -526,7 +529,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() var result = await client.CallToolAsync("sampleLLM", new Dictionary() { ["prompt"] = "In just a few words, what is the most famous tower in Paris?", - }, TestContext.Current.CancellationToken); + }, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotEmpty(result.Content); @@ -539,22 +542,11 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() [MemberData(nameof(GetClients))] public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { - // arrange - JsonSerializerOptions jsonSerializerOptions = new(JsonSerializerDefaults.Web) - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver(), - Converters = { new JsonStringEnumConverter() }, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - NumberHandling = JsonNumberHandling.AllowReadingFromString, - Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, - }; - TaskCompletionSource receivedNotification = new(); await using var client = await _fixture.CreateClientAsync(clientId); client.AddNotificationHandler(NotificationMethods.LoggingMessageNotification, (notification) => { - var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params!.ToString() ?? string.Empty, - jsonSerializerOptions); + var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); if (loggingMessageNotificationParameters is not null) { receivedNotification.TrySetResult(true); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 846ceebcc..26d834b97 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -6,12 +6,13 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; using System.ComponentModel; using System.IO.Pipelines; using System.Threading.Channels; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Tests.Configuration; public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable @@ -28,9 +29,72 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - _builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts(); + _builder = sc + .AddMcpServer() + .WithStdioServerTransport() + .WithListPromptsHandler(async (request, cancellationToken) => + { + var cursor = request.Params?.Cursor; + switch (cursor) + { + case null: + return new() + { + NextCursor = "abc", + Prompts = [new() + { + Name = "FirstCustomPrompt", + Description = "First prompt returned by custom handler", + }], + }; + + case "abc": + return new() + { + NextCursor = "def", + Prompts = [new() + { + Name = "SecondCustomPrompt", + Description = "Second prompt returned by custom handler", + }], + }; + + case "def": + return new() + { + NextCursor = null, + Prompts = [new() + { + Name = "FinalCustomPrompt", + Description = "Final prompt returned by custom handler", + }], + }; + + default: + throw new Exception("Unexpected cursor"); + } + }) + .WithGetPromptHandler(async (request, cancellationToken) => + { + switch (request.Params?.Name) + { + case "FirstCustomPrompt": + case "SecondCustomPrompt": + case "FinalCustomPrompt": + return new GetPromptResult() + { + Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }], + }; + + default: + throw new Exception($"Unknown prompt '{request.Params?.Name}'"); + } + }) + .WithPrompts(); + + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -55,19 +119,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + serverOutput: _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -87,12 +149,12 @@ public async Task Can_List_And_Call_Registered_Prompts() IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages)); Assert.Equal("Returns chat messages", prompt.Description); - var result = await prompt.GetAsync(new Dictionary() { ["message"] = "hello" }, TestContext.Current.CancellationToken); + var result = await prompt.GetAsync(new Dictionary() { ["message"] = "hello" }, cancellationToken: TestContext.Current.CancellationToken); var chatMessages = result.ToChatMessages(); Assert.NotNull(chatMessages); @@ -100,6 +162,14 @@ public async Task Can_List_And_Call_Registered_Prompts() Assert.Equal(2, chatMessages.Count); Assert.Equal("The prompt is: hello", chatMessages[0].Text); Assert.Equal("Summarize.", chatMessages[1].Text); + + prompt = prompts.First(t => t.Name == "SecondCustomPrompt"); + Assert.Equal("Second prompt returned by custom handler", prompt.Description); + result = await prompt.GetAsync(cancellationToken: TestContext.Current.CancellationToken); + chatMessages = result.ToChatMessages(); + Assert.NotNull(chatMessages); + Assert.Single(chatMessages); + Assert.Equal("hello from SecondCustomPrompt", chatMessages[0].Text); } [Fact] @@ -108,7 +178,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler("notifications/prompts/list_changed", notification => @@ -129,7 +199,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() await notificationRead; prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(4, prompts.Count); + Assert.Equal(7, prompts.Count); Assert.Contains(prompts, t => t.Name == "NewPrompt"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -138,7 +208,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() await notificationRead; prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); Assert.DoesNotContain(prompts, t => t.Name == "NewPrompt"); } @@ -220,7 +290,7 @@ public void Register_Prompts_From_Multiple_Sources() Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(MorePrompts.AnotherPrompt)); } - [McpServerToolType] + [McpServerPromptType] public sealed class SimplePrompts(ObjectWithId? id = null) { [McpServerPrompt, Description("Returns chat messages")] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 73fdeadef..3c8981b63 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -7,7 +7,6 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; using System.Collections.Concurrent; using System.ComponentModel; @@ -16,6 +15,8 @@ using System.Text.RegularExpressions; using System.Threading.Channels; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Tests.Configuration; public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable @@ -32,9 +33,92 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - _builder = sc.AddMcpServer().WithStdioServerTransport().WithTools(); + _builder = sc + .AddMcpServer() + .WithStdioServerTransport() + .WithListToolsHandler(async (request, cancellationToken) => + { + var cursor = request.Params?.Cursor; + switch (cursor) + { + case null: + return new() + { + NextCursor = "abc", + Tools = [new() + { + Name = "FirstCustomTool", + Description = "First tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + case "abc": + return new() + { + NextCursor = "def", + Tools = [new() + { + Name = "SecondCustomTool", + Description = "Second tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + case "def": + return new() + { + NextCursor = null, + Tools = [new() + { + Name = "FinalCustomTool", + Description = "Third tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + default: + throw new Exception("Unexpected cursor"); + } + }) + .WithCallToolHandler(async (request, cancellationToken) => + { + switch (request.Params?.Name) + { + case "FirstCustomTool": + case "SecondCustomTool": + case "FinalCustomTool": + return new CallToolResponse() + { + Content = [new Content() { Text = $"{request.Params.Name}Result", Type = "text" }], + }; + + default: + throw new Exception($"Unknown tool '{request.Params?.Name}'"); + } + }) + .WithTools(); + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -59,19 +143,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -90,8 +172,8 @@ public async Task Can_List_Registered_Tools() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -117,28 +199,26 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdinPipe = new Pipe(); var stdoutPipe = new Pipe(); - await using var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); await using var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); - using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = $"TestServer_{i}", - Name = $"TestServer_{i}", - TransportType = "ignored", - }; - await using (var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = $"TestServer_{i}", + Name = $"TestServer_{i}", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: stdinPipe.Writer.AsStream(), + serverOutput: stdoutPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) { - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -164,8 +244,8 @@ public async Task Can_Be_Notified_Of_Tool_Changes() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => @@ -185,8 +265,8 @@ public async Task Can_Be_Notified_Of_Tool_Changes() serverTools.Add(newTool); await notificationRead; - tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(14, tools.Count); + tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(17, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -194,8 +274,8 @@ public async Task Can_Be_Notified_Of_Tool_Changes() serverTools.Remove(newTool); await notificationRead; - tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(16, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -207,7 +287,7 @@ public async Task Can_Call_Registered_Tool() var result = await client.CallToolAsync( "Echo", new Dictionary() { ["message"] = "Peter" }, - TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Content); @@ -225,13 +305,20 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() var result = await client.CallToolAsync( "EchoArray", new Dictionary() { ["message"] = "Peter" }, - TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); - Assert.Equal("hello Peter", result.Content[0].Text); Assert.Equal("hello2 Peter", result.Content[1].Text); + + result = await client.CallToolAsync( + "SecondCustomTool", + cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + Assert.Equal("SecondCustomToolResult", result.Content[0].Text); } [Fact] @@ -421,20 +508,33 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) } [Theory] - [InlineData(false)] - [InlineData(true)] - public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(bool parameterInServices) + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + [InlineData(null)] + public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime? lifetime) { ServiceCollection sc = new(); - if (parameterInServices) + switch (lifetime) { - sc.AddSingleton(new ComplexObject()); + case ServiceLifetime.Singleton: + sc.AddSingleton(new ComplexObject()); + break; + + case ServiceLifetime.Scoped: + sc.AddScoped(_ => new ComplexObject()); + break; + + case ServiceLifetime.Transient: + sc.AddTransient(_ => new ComplexObject()); + break; } + sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); - if (parameterInServices) + if (lifetime is not null) { Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema)); } @@ -449,7 +549,7 @@ public async Task Recognizes_Parameter_Types() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); Assert.NotEmpty(tools); @@ -525,26 +625,25 @@ public async Task HandlesIProgressParameter() IMcpClient client = await CreateMcpClientForServer(); client.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => { - ProgressNotification pn = JsonSerializer.Deserialize((JsonElement)notification.Params!)!; + ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; notifications.Enqueue(pn); return Task.CompletedTask; }); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); Assert.NotEmpty(tools); McpClientTool progressTool = tools.First(t => t.Name == nameof(EchoTool.SendsProgressNotifications)); - var result = await client.SendRequestAsync(new JsonRpcRequest() - { - Method = RequestMethods.ToolsCall, - Params = new CallToolRequestParams() + var result = await client.SendRequestAsync( + RequestMethods.ToolsCall, + new CallToolRequestParams { Name = progressTool.ProtocolTool.Name, Meta = new() { ProgressToken = new("abc123") }, }, - }, TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); Assert.Contains("done", JsonSerializer.Serialize(result)); SpinWait.SpinUntil(() => notifications.Count == 10, TimeSpan.FromSeconds(10)); @@ -553,7 +652,7 @@ public async Task HandlesIProgressParameter() Assert.Equal(10, array.Length); for (int i = 0; i < array.Length; i++) { - Assert.Equal("\"abc123\"", array[i].ProgressToken.ToString()); + Assert.Equal("abc123", array[i].ProgressToken.ToString()); Assert.Equal(i, array[i].Progress.Progress); Assert.Equal(10, array[i].Progress.Total); Assert.Equal($"Progress {i}", array[i].Progress.Message); @@ -565,18 +664,17 @@ public async Task CancellationNotificationsPropagateToToolTokens() { IMcpClient client = await CreateMcpClientForServer(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); Assert.NotEmpty(tools); McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation)); var requestId = new RequestId(Guid.NewGuid().ToString()); - var invokeTask = client.SendRequestAsync(new JsonRpcRequest() - { - Method = RequestMethods.ToolsCall, - Id = requestId, - Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name }, - }, TestContext.Current.CancellationToken); + var invokeTask = client.SendRequestAsync( + RequestMethods.ToolsCall, + new CallToolRequestParams { Name = cancelableTool.ProtocolTool.Name }, + requestId: requestId, + cancellationToken: TestContext.Current.CancellationToken); await client.SendNotificationAsync( NotificationMethods.CancelledNotification, diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs new file mode 100644 index 000000000..583ae2743 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -0,0 +1,79 @@ +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using OpenTelemetry.Trace; +using System.Diagnostics; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests; + +[Collection(nameof(DisableParallelization))] +public class DiagnosticTests +{ + [Fact] + public async Task Session_TracksActivities() + { + var activities = new List(); + + using (var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource("Experimental.ModelContextProtocol") + .AddInMemoryExporter(activities) + .Build()) + { + await RunConnected(async (client, server) => + { + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(tools); + Assert.NotEmpty(tools); + + var tool = tools.First(t => t.Name == "DoubleValue"); + await tool.InvokeAsync(new Dictionary() { ["amount"] = 42 }, TestContext.Current.CancellationToken); + }); + } + + Assert.NotEmpty(activities); + + Activity toolCallActivity = activities.First(a => + a.Tags.Any(t => t.Key == "rpc.method" && t.Value == "tools/call")); + Assert.Equal("DoubleValue", toolCallActivity.Tags.First(t => t.Key == "mcp.request.params.name").Value); + } + + private static async Task RunConnected(Func action) + { + Pipe clientToServerPipe = new(), serverToClientPipe = new(); + StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); + StreamClientTransport clientTransport = new(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream()); + + Task serverTask; + + await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() + { + ServerInfo = new Implementation { Name = "TestServer", Version = "1.0.0" }, + Capabilities = new() + { + Tools = new() + { + ToolCollection = [McpServerTool.Create((int amount) => amount * 2, new() { Name = "DoubleValue", Description = "Doubles the value." })], + } + } + })) + { + serverTask = server.RunAsync(TestContext.Current.CancellationToken); + + await using (IMcpClient client = await McpClientFactory.CreateAsync(new() + { + Id = "TestServer", + Name = "TestServer", + TransportType = TransportTypes.StdIo, + }, + createTransportFunc: (_, __) => clientTransport, + cancellationToken: TestContext.Current.CancellationToken)) + { + await action(client, server); + } + } + + await serverTask; + } +} diff --git a/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs b/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs new file mode 100644 index 000000000..29385fb80 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/McpJsonUtilitiesTests.cs @@ -0,0 +1,28 @@ +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +public static class McpJsonUtilitiesTests +{ + [Fact] + public static void DefaultOptions_IsSingleton() + { + var options = McpJsonUtilities.DefaultOptions; + + Assert.NotNull(options); + Assert.True(options.IsReadOnly); + Assert.Same(options, McpJsonUtilities.DefaultOptions); + } + + [Fact] + public static void DefaultOptions_UseReflectionWhenEnabled() + { + var options = McpJsonUtilities.DefaultOptions; + bool isReflectionEnabled = JsonSerializer.IsReflectionEnabledByDefault; + Type anonType = new { Id = 42 }.GetType(); + + Assert.True(isReflectionEnabled); // To be disabled once https://github.com/dotnet/extensions/pull/6241 is incorporated + Assert.Equal(isReflectionEnabled, options.TryGetTypeInfo(anonType, out _)); + } +} diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 7a239ef29..c8dc1ca90 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -1,7 +1,7 @@  - net8.0 + net9.0;net8.0 enable enable Latest @@ -20,11 +20,14 @@ runtime; build; native; contentfiles; analyzers; buildtransitive all + + + diff --git a/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs b/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs new file mode 100644 index 000000000..1df5ccb73 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/RequestIdTests.cs @@ -0,0 +1,38 @@ +using ModelContextProtocol.Protocol.Messages; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol; + +public class RequestIdTests +{ + [Fact] + public void StringCtor_Roundtrips() + { + RequestId id = new("test-id"); + Assert.Equal("test-id", id.ToString()); + Assert.Equal("\"test-id\"", JsonSerializer.Serialize(id)); + Assert.Same("test-id", id.Id); + + Assert.True(id.Equals(new("test-id"))); + Assert.False(id.Equals(new("tEst-id"))); + Assert.Equal("test-id".GetHashCode(), id.GetHashCode()); + + Assert.Equal(id, JsonSerializer.Deserialize(JsonSerializer.Serialize(id))); + } + + [Fact] + public void Int64Ctor_Roundtrips() + { + RequestId id = new(42); + Assert.Equal("42", id.ToString()); + Assert.Equal("42", JsonSerializer.Serialize(id)); + Assert.Equal(42, Assert.IsType(id.Id)); + + Assert.True(id.Equals(new(42))); + Assert.False(id.Equals(new(43))); + Assert.False(id.Equals(new("42"))); + Assert.Equal(42L.GetHashCode(), id.GetHashCode()); + + Assert.Equal(id, JsonSerializer.Deserialize(JsonSerializer.Serialize(id))); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs index 25c5123ef..ae640eccd 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs @@ -1,8 +1,6 @@ -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; -using Moq; namespace ModelContextProtocol.Tests.Server; @@ -25,7 +23,8 @@ public McpServerFactoryTests(ITestOutputHelper testOutputHelper) public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using IMcpServer server = McpServerFactory.Create(Mock.Of(), _options, LoggerFactory); + await using var transport = new TestServerTransport(); + await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -39,9 +38,10 @@ public void Create_Throws_For_Null_ServerTransport() } [Fact] - public void Create_Throws_For_Null_Options() + public async Task Create_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws("serverOptions", () => McpServerFactory.Create(Mock.Of(), null!, LoggerFactory)); + await using var transport = new TestServerTransport(); + Assert.Throws("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory)); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 65ed07dff..e0a4a6e49 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,28 +1,23 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; -using Moq; using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol.Tests.Server; public class McpServerTests : LoggedTest { - private readonly Mock _serverTransport; private readonly McpServerOptions _options; - private readonly IServiceProvider _serviceProvider; public McpServerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _serverTransport = new Mock(); _options = CreateOptions(); - _serviceProvider = new ServiceCollection().BuildServiceProvider(); } private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = null) @@ -40,7 +35,8 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -50,21 +46,23 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory)); } [Fact] - public void Constructor_Throws_For_Null_Options() + public async Task Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(_serverTransport.Object, null!, LoggerFactory, _serviceProvider)); + await using var transport = new TestServerTransport(); + Assert.Throws(() => McpServerFactory.Create(transport, null!, LoggerFactory)); } [Fact] public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, null, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, null); // Assert Assert.NotNull(server); @@ -74,7 +72,8 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, null); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -84,27 +83,23 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Running() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert await Assert.ThrowsAsync(() => server.RunAsync(TestContext.Current.CancellationToken)); - try - { - await runTask; - } - catch (NullReferenceException) - { - // _serverTransport.Object returns a null MessageReader - } + await transport.DisposeAsync(); + await runTask; } [Fact] public async Task RequestSamplingAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Sampling() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); var action = () => server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -118,7 +113,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -139,7 +134,8 @@ public async Task RequestSamplingAsync_Should_SendRequest() public async Task RequestRootsAsync_Should_Throw_McpServerException_If_Client_Does_Not_Support_Roots() { // Arrange - await using var server = McpServerFactory.Create(_serverTransport.Object, _options, LoggerFactory, _serviceProvider); + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -151,7 +147,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -177,7 +173,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); + JsonObject jObj = Assert.IsType(response); + Assert.Empty(jObj); }); } @@ -190,9 +187,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (InitializeResult)response; + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result); Assert.Equal("TestServer", result.ServerInfo.Name); Assert.Equal("1.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); @@ -208,10 +204,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.Completion); Assert.Empty(result.Completion.Values); Assert.Equal(0, result.Completion.Total); Assert.False(result.Completion.HasMore); @@ -239,10 +233,8 @@ await Can_Handle_Requests( }, assertResult: response => { - Assert.IsType(response); - - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); + CompleteResult? result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.Completion); Assert.NotEmpty(result.Completion.Values); Assert.Equal("test", result.Completion.Values[0]); Assert.Equal(2, result.Completion.Total); @@ -279,10 +271,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (ListResourceTemplatesResult)response; - Assert.NotNull(result.ResourceTemplates); + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.ResourceTemplates); Assert.NotEmpty(result.ResourceTemplates); Assert.Equal("test", result.ResourceTemplates[0].UriTemplate); }); @@ -310,10 +300,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (ListResourcesResult)response; - Assert.NotNull(result.Resources); + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.Resources); Assert.NotEmpty(result.Resources); Assert.Equal("test", result.Resources[0].Uri); }); @@ -337,7 +325,7 @@ await Can_Handle_Requests( { return Task.FromResult(new ReadResourceResult { - Contents = [new TextResourceContents() { Text = "test" }] + Contents = [new TextResourceContents { Text = "test" }] }); }, ListResourcesHandler = (request, ct) => throw new NotImplementedException(), @@ -347,10 +335,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (ReadResourceResult)response; - Assert.NotNull(result.Contents); + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.Contents); Assert.NotEmpty(result.Contents); TextResourceContents textResource = Assert.IsType(result.Contents[0]); @@ -386,10 +372,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (ListPromptsResult)response; - Assert.NotNull(result.Prompts); + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result?.Prompts); Assert.NotEmpty(result.Prompts); Assert.Equal("test", result.Prompts[0].Name); }); @@ -417,9 +401,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (GetPromptResult)response; + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result); Assert.Equal("test", result.Description); }); } @@ -452,9 +435,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (ListToolsResult)response; + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result); Assert.NotEmpty(result.Tools); Assert.Equal("test", result.Tools[0].Name); }); @@ -488,9 +470,8 @@ await Can_Handle_Requests( configureOptions: null, assertResult: response => { - Assert.IsType(response); - - var result = (CallToolResponse)response; + var result = JsonSerializer.Deserialize(response); + Assert.NotNull(result); Assert.NotEmpty(result.Content); Assert.Equal("test", result.Content[0].Text); }); @@ -502,13 +483,13 @@ public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_A await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, RequestMethods.ToolsCall, "CallTool handler not configured"); } - private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) + private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action? configureOptions, Action assertResult) { await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider); + await using var server = McpServerFactory.Create(transport, options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -530,7 +511,6 @@ await transport.SendMessageAsync( var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(5)); Assert.NotNull(response); - Assert.NotNull(response.Result); assertResult(response.Result); @@ -543,7 +523,7 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory, _serviceProvider)); + Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory)); } [Fact] @@ -554,7 +534,6 @@ public async Task AsSamplingChatClient_NoSamplingSupport_Throws() Assert.Throws("server", () => server.AsSamplingChatClient()); } - [Fact] public async Task AsSamplingChatClient_HandlesRequestResponse() { @@ -584,6 +563,26 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); } + [Fact] + public async Task Can_SendMessage_Before_RunAsync() + { + await using var transport = new TestServerTransport(); + await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + + var logNotification = new JsonRpcNotification() + { + Method = NotificationMethods.LoggingMessageNotification + }; + await server.SendMessageAsync(logNotification, TestContext.Current.CancellationToken); + + var runTask = server.RunAsync(TestContext.Current.CancellationToken); + await transport.DisposeAsync(); + await runTask; + + Assert.NotEmpty(transport.SentMessages); + Assert.Same(logNotification, transport.SentMessages[0]); + } + private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) { PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); @@ -597,10 +596,11 @@ private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServe supportsSampling ? new ClientCapabilities { Sampling = new SamplingCapability() } : null; - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where T : class + public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { - CreateMessageRequestParams rp = Assert.IsType(request.Params); + CreateMessageRequestParams? rp = JsonSerializer.Deserialize(request.Params); + Assert.NotNull(rp); Assert.Equal(0.75f, rp.Temperature); Assert.Equal(42, rp.MaxTokens); Assert.Equal(["."], rp.StopSequences); @@ -621,7 +621,12 @@ public Task SendRequestAsync(JsonRpcRequest request, CancellationToken can Role = "assistant", StopReason = "endTurn", }; - return Task.FromResult((T)(object)result); + + return Task.FromResult(new JsonRpcResponse + { + Id = new RequestId("0"), + Result = JsonSerializer.SerializeToNode(result), + }); } public ValueTask DisposeAsync() => default; @@ -636,4 +641,48 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); } + + [Fact] + public async Task NotifyProgress_Should_Be_Handled() + { + await using TestServerTransport transport = new(); + var options = CreateOptions(); + + var notificationReceived = new TaskCompletionSource(); + + var server = McpServerFactory.Create(transport, options, LoggerFactory); + server.AddNotificationHandler(NotificationMethods.ProgressNotification, notification => + { + notificationReceived.SetResult(notification); + return Task.CompletedTask; + }); + + Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); + + await transport.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.ProgressNotification, + Params = JsonSerializer.SerializeToNode(new ProgressNotification + { + ProgressToken = new("abc"), + Progress = new() + { + Progress = 50, + Total = 100, + Message = "Progress message", + }, + }), + }, TestContext.Current.CancellationToken); + + var notification = await notificationReceived.Task; + var progress = JsonSerializer.Deserialize(notification.Params); + Assert.NotNull(progress); + Assert.Equal("abc", progress.ProgressToken.ToString()); + Assert.Equal(50, progress.Progress.Progress); + Assert.Equal(100, progress.Progress.Total); + Assert.Equal("Progress message", progress.Progress.Message); + + await server.DisposeAsync(); + await serverTask; + } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 2d747d9dc..30d13fd9f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -45,20 +45,47 @@ public async Task SupportsIMcpServer() Assert.Equal("42", result.Content[0].Text); } - [Fact] - public async Task SupportsServiceFromDI() + [Theory] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime) { - MyService expectedMyService = new(); + MyService singletonService = new(); ServiceCollection sc = new(); - sc.AddSingleton(expectedMyService); - IServiceProvider services = sc.BuildServiceProvider(); + switch (injectedArgumentLifetime) + { + case ServiceLifetime.Singleton: + sc.AddSingleton(singletonService); + break; + + case ServiceLifetime.Scoped: + sc.AddScoped(_ => new MyService()); + break; + + case ServiceLifetime.Transient: + sc.AddTransient(_ => new MyService()); + break; + } - McpServerTool tool = McpServerTool.Create((MyService actualMyService) => + sc.AddSingleton(services => { - Assert.Same(expectedMyService, actualMyService); - return "42"; - }, new() { Services = services }); + return McpServerTool.Create((MyService actualMyService) => + { + Assert.NotNull(actualMyService); + if (injectedArgumentLifetime == ServiceLifetime.Singleton) + { + Assert.Same(singletonService, actualMyService); + } + + return "42"; + }, new() { Services = services }); + }); + + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerTool tool = services.GetRequiredService(); Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema)); diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 1874953d9..0622e656e 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -3,17 +3,30 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; -using System.Text.Json; namespace ModelContextProtocol.Tests; public class SseIntegrationTests(ITestOutputHelper outputHelper) : LoggedTest(outputHelper) { + /// Port number to be grabbed by the next test. + private static int s_nextPort = 3000; + + // If the tests run concurrently against different versions of the runtime, tests can conflict with + // each other in the ports set up for interacting with containers. Ensure that such suites running + // against different TFMs use different port numbers. + private static readonly int s_portOffset = 1000 * (Environment.Version.Major switch + { + int v when v >= 8 => Environment.Version.Major - 7, + _ => 0, + }); + + private static int CreatePortNumber() => Interlocked.Increment(ref s_nextPort) + s_portOffset; + [Fact] public async Task ConnectAndReceiveMessage_InMemoryServer() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); var defaultOptions = new McpClientOptions @@ -27,7 +40,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act @@ -41,7 +54,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, TestContext.Current.CancellationToken); + await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.True(true); @@ -53,7 +66,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - int port = 3001; + int port = CreatePortNumber(); await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); @@ -78,7 +91,7 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotEmpty(tools); @@ -90,7 +103,7 @@ public async Task Sampling_Sse_EverythingServer() { Assert.SkipWhen(!EverythingSseServerFixture.IsDockerAvailable, "docker is not available"); - int port = 3002; + int port = CreatePortNumber(); await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); @@ -116,7 +129,7 @@ public async Task Sampling_Sse_EverythingServer() { Sampling = new() { - SamplingHandler = (_, _) => + SamplingHandler = (_, _, _) => { samplingHandlerCalls++; return Task.FromResult(new CreateMessageResult @@ -145,7 +158,7 @@ public async Task Sampling_Sse_EverythingServer() { ["prompt"] = "Test prompt", ["maxTokens"] = 100 - }, TestContext.Current.CancellationToken); + }, cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -158,11 +171,10 @@ public async Task Sampling_Sse_EverythingServer() public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); server.UseFullUrlForEndpointEvent = true; await server.StartAsync(); - var defaultOptions = new McpClientOptions { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } @@ -174,7 +186,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act @@ -188,7 +200,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); // Send a test message through POST endpoint - await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, TestContext.Current.CancellationToken); + await client.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken); // Assert Assert.True(true); @@ -198,7 +210,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU public async Task ConnectAndReceiveNotification_InMemoryServer() { // Arrange - await using InMemoryTestSseServer server = new(logger: LoggerFactory.CreateLogger()); + await using InMemoryTestSseServer server = new(CreatePortNumber(), LoggerFactory.CreateLogger()); await server.StartAsync(); @@ -213,7 +225,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() Name = "In-memory Test Server", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:5000/sse" + Location = $"http://localhost:{server.Port}/sse" }; // Act @@ -229,7 +241,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() var receivedNotification = new TaskCompletionSource(); client.AddNotificationHandler("test/notification", (args) => { - var msg = ((JsonElement?)args.Params)?.GetProperty("message").GetString(); + var msg = args.Params?["message"]?.GetValue(); receivedNotification.SetResult(msg); return Task.CompletedTask; diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 75aefc086..238c7747c 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; @@ -18,16 +17,19 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable public SseServerIntegrationTestFixture() { + // Ensure that test suites running against different TFMs and possibly concurrently use different port numbers. + int port = 3001 + Environment.Version.Major; + DefaultConfig = new McpServerConfig { Id = "test_server", Name = "TestServer", TransportType = TransportTypes.Sse, TransportOptions = [], - Location = "http://localhost:3001/sse" + Location = $"http://localhost:{port}/sse" }; - _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); + _serverTask = Program.MainAsync([port.ToString()], new XunitLoggerProvider(_delegatingTestOutputHelper), _stopCts.Token); } public static McpClientOptions CreateDefaultClientOptions() => new() diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs index 44befcd10..b73a9c06e 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTests.cs @@ -1,6 +1,8 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; +using System.Net; +using System.Text; namespace ModelContextProtocol.Tests; @@ -30,6 +32,12 @@ private Task GetClientAsync(McpClientOptions? options = null) cancellationToken: TestContext.Current.CancellationToken); } + private HttpClient GetHttpClient() => + new() + { + BaseAddress = new(_fixture.DefaultConfig.Location!), + }; + [Fact] public async Task ConnectAndPing_Sse_TestServer() { @@ -63,7 +71,7 @@ public async Task ListTools_Sse_TestServer() // act await using var client = await GetClientAsync(); - var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(tools); @@ -82,7 +90,7 @@ public async Task CallTool_Sse_EchoServer() { ["message"] = "Hello MCP!" }, - TestContext.Current.CancellationToken + cancellationToken: TestContext.Current.CancellationToken ); // assert @@ -168,7 +176,7 @@ public async Task GetPrompt_Sse_SimplePrompt() // act await using var client = await GetClientAsync(); - var result = await client.GetPromptAsync("simple_prompt", null, TestContext.Current.CancellationToken); + var result = await client.GetPromptAsync("simple_prompt", null, cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -187,7 +195,7 @@ public async Task GetPrompt_Sse_ComplexPrompt() { "temperature", "0.7" }, { "style", "formal" } }; - var result = await client.GetPromptAsync("complex_prompt", arguments, TestContext.Current.CancellationToken); + var result = await client.GetPromptAsync("complex_prompt", arguments, cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -202,7 +210,7 @@ public async Task GetPrompt_Sse_NonExistent_ThrowsException() // act await using var client = await GetClientAsync(); await Assert.ThrowsAsync(() => - client.GetPromptAsync("non_existent_prompt", null, TestContext.Current.CancellationToken)); + client.GetPromptAsync("non_existent_prompt", null, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] @@ -215,7 +223,7 @@ public async Task Sampling_Sse_TestServer() var options = SseServerIntegrationTestFixture.CreateDefaultClientOptions(); options.Capabilities ??= new(); options.Capabilities.Sampling ??= new(); - options.Capabilities.Sampling.SamplingHandler = async (_, _) => + options.Capabilities.Sampling.SamplingHandler = async (_, _, _) => { samplingHandlerCalls++; return new CreateMessageResult @@ -238,7 +246,7 @@ public async Task Sampling_Sse_TestServer() ["prompt"] = "Test prompt", ["maxTokens"] = 100 }, - TestContext.Current.CancellationToken); + cancellationToken: TestContext.Current.CancellationToken); // assert Assert.NotNull(result); @@ -262,7 +270,7 @@ public async Task CallTool_Sse_EchoServer_Concurrently() { ["message"] = $"Hello MCP! {i}" }, - TestContext.Current.CancellationToken + cancellationToken: TestContext.Current.CancellationToken ); Assert.NotNull(result); @@ -271,4 +279,44 @@ public async Task CallTool_Sse_EchoServer_Concurrently() Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); } } + + [Fact] + public async Task EventSourceStream_Includes_MessageEventType() + { + // Simulate our own MCP client handshake using a plain HttpClient so we can look for "event: message" + // in the raw SSE response stream which is not exposed by the real MCP client. + using var httpClient = GetHttpClient(); + await using var sseResponse = await httpClient.GetStreamAsync("", TestContext.Current.CancellationToken); + using var streamReader = new StreamReader(sseResponse); + + var endpointEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.Equal("event: endpoint", endpointEvent); + + var endpointData = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.NotNull(endpointData); + Assert.StartsWith("data: ", endpointData); + var messageEndpoint = endpointData["data: ".Length..]; + + const string initializeRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} + """; + using (var initializeRequestBody = new StringContent(initializeRequest, Encoding.UTF8, "application/json")) + { + var response = await httpClient.PostAsync(messageEndpoint, initializeRequestBody, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + } + + const string initializedNotification = """ + {"jsonrpc":"2.0","method":"notifications/initialized"} + """; + using (var initializedNotificationBody = new StringContent(initializedNotification, Encoding.UTF8, "application/json")) + { + var response = await httpClient.PostAsync(messageEndpoint, initializedNotificationBody, TestContext.Current.CancellationToken); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + } + + Assert.Equal("", await streamReader.ReadLineAsync(TestContext.Current.CancellationToken)); + var messageEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); + Assert.Equal("event: message", messageEvent); + } } diff --git a/tests/ModelContextProtocol.Tests/TestAttributes.cs b/tests/ModelContextProtocol.Tests/TestAttributes.cs index 4edbce6ec..8a0140db8 100644 --- a/tests/ModelContextProtocol.Tests/TestAttributes.cs +++ b/tests/ModelContextProtocol.Tests/TestAttributes.cs @@ -1,2 +1,9 @@ -// Uncomment to disable parallel test execution -//[assembly: CollectionBehavior(DisableTestParallelization = true)] \ No newline at end of file +// Uncomment to disable parallel test execution for the whole assembly +//[assembly: CollectionBehavior(DisableTestParallelization = true)] + +/// +/// Enables test classes to individually be attributed as [Collection(nameof(DisableParallelization))] +/// to have those tests run non-concurrently with any other tests. +/// +[CollectionDefinition(nameof(DisableParallelization), DisableParallelization = true)] +public sealed class DisableParallelization; \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 23061cd9c..8fe8e91c1 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -213,7 +213,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() Assert.True(session.MessageReader.TryRead(out var message)); Assert.NotNull(message); Assert.IsType(message); - Assert.Equal("\"44\"", ((JsonRpcRequest)message).Id.ToString()); + Assert.Equal("44", ((JsonRpcRequest)message).Id.ToString()); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 5857f3c4d..33793555d 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -8,6 +8,7 @@ using System.IO.Pipelines; using System.Text; using System.Text.Json; +using System.Text.Json.Nodes; namespace ModelContextProtocol.Tests.Transport; @@ -55,7 +56,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public async Task Should_Start_In_Connected_State() { - await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), Stream.Null, LoggerFactory); + await using var transport = new StreamServerTransport(new Pipe().Reader.AsStream(), Stream.Null, loggerFactory: LoggerFactory); Assert.True(transport.IsConnected); } @@ -65,11 +66,10 @@ public async Task SendMessageAsync_Should_Send_Message() { using var output = new MemoryStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( new Pipe().Reader.AsStream(), output, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -87,7 +87,7 @@ public async Task SendMessageAsync_Should_Send_Message() [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); + await using var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory); await transport.DisposeAsync(); @@ -104,11 +104,10 @@ public async Task ReadMessagesAsync_Should_Read_Messages() Pipe pipe = new(); using var input = pipe.Reader.AsStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( input, Stream.Null, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -128,7 +127,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() [Fact] public async Task CleanupAsync_Should_Cleanup_Resources() { - var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); + var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory); await transport.DisposeAsync(); @@ -141,11 +140,10 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Use a reader that won't terminate using var output = new MemoryStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( new Pipe().Reader.AsStream(), output, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -156,10 +154,10 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() { Method = "test", Id = new RequestId(44), - Params = new Dictionary + Params = new JsonObject { - ["text"] = JsonSerializer.SerializeToElement(chineseText) - } + ["text"] = chineseText + }, }; // Clear output and send message @@ -178,10 +176,10 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() { Method = "test", Id = new RequestId(45), - Params = new Dictionary + Params = new JsonObject { - ["text"] = JsonSerializer.SerializeToElement(emojiText) - } + ["text"] = emojiText + }, }; // Clear output and send message diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs deleted file mode 100644 index d41f0b979..000000000 --- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs +++ /dev/null @@ -1,79 +0,0 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Utils.Json; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Transport; - -internal sealed class StreamClientTransport : TransportBase, IClientTransport -{ - private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; - private readonly Task? _readTask; - private readonly CancellationTokenSource _shutdownCts = new CancellationTokenSource(); - private readonly TextReader _serverStdoutReader; - private readonly TextWriter _serverStdinWriter; - - public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader, ILoggerFactory loggerFactory) - : base(loggerFactory) - { - _serverStdoutReader = serverStdoutReader; - _serverStdinWriter = serverStdinWriter; - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); - SetConnected(true); - } - - public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); - - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - string id = message is IJsonRpcMessageWithId messageWithId ? - messageWithId.Id.ToString() : - "(no id)"; - - await _serverStdinWriter.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false); - await _serverStdinWriter.FlushAsync(cancellationToken).ConfigureAwait(false); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) - { - try - { - while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) - { - if (!string.IsNullOrWhiteSpace(line)) - { - try - { - if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message) - { - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - } - } - catch (JsonException) - { - } - } - } - } - catch (OperationCanceledException) - { - } - } - - public override async ValueTask DisposeAsync() - { - if (_shutdownCts is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - } - - if (_readTask is Task readTask) - { - await readTask.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false); - } - - SetConnected(false); - } -} diff --git a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs index 0bdfde192..7d7122a8e 100644 --- a/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs +++ b/tests/ModelContextProtocol.Tests/Utils/InMemoryTestSseServer.cs @@ -26,6 +26,8 @@ public sealed class InMemoryTestSseServer : IAsyncDisposable public InMemoryTestSseServer(int port = 5000, ILogger? logger = null) { + Port = port; + _listener = new HttpListener(); _listener.Prefixes.Add($"http://localhost:{port}/"); _cts = new CancellationTokenSource(); @@ -35,6 +37,8 @@ public InMemoryTestSseServer(int port = 5000, ILogger? lo _messagePath = "/message"; } + public int Port { get; } + /// /// This is to be able to use the full URL for the endpoint event. /// @@ -289,7 +293,7 @@ private static async Task HandleInitializationRequest(HttpListenerResponse respo var jsonRpcResponse = new JsonRpcResponse() { Id = jsonRpcRequest.Id, - Result = new InitializeResult() + Result = JsonSerializer.SerializeToNode(new InitializeResult { ProtocolVersion = "2024-11-05", Capabilities = new(), @@ -298,7 +302,7 @@ private static async Task HandleInitializationRequest(HttpListenerResponse respo Name = "ExampleServer", Version = "1.0.0" } - } + }) }; // Echo back to the HTTP response @@ -369,7 +373,7 @@ public async Task SendTestNotificationAsync(string content) { JsonRpc = "2.0", Method = "test/notification", - Params = new { message = content } + Params = JsonSerializer.SerializeToNode(new { message = content }), }; var serialized = JsonSerializer.Serialize(notification); diff --git a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs index 33a133616..f21660143 100644 --- a/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs +++ b/tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs @@ -1,7 +1,8 @@ -using System.Threading.Channels; -using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; +using System.Text.Json; +using System.Threading.Channels; namespace ModelContextProtocol.Tests.Utils; @@ -59,10 +60,10 @@ private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellat await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = new ModelContextProtocol.Protocol.Types.ListRootsResult + Result = JsonSerializer.SerializeToNode(new ListRootsResult { Roots = [] - } + }), }, cancellationToken); } @@ -71,7 +72,7 @@ private async Task Sampling(JsonRpcRequest request, CancellationToken cancellati await WriteMessageAsync(new JsonRpcResponse { Id = request.Id, - Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" } + Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = "role" }), }, cancellationToken); }