Skip to content

Fix pagination handling in McpServer #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics.CodeAnalysis;
Expand Down
61 changes: 15 additions & 46 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
Expand Down Expand Up @@ -214,41 +213,26 @@ private void SetPromptsHandler(McpServerOptions options)
throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together.");
}

// Handle tools provided via DI.
// Handle prompts provided via DI.
if (prompts is { IsEmpty: false })
{
// Synthesize the handlers, making sure a PromptsCapability is specified.
var originalListPromptsHandler = listPromptsHandler;
var originalGetPromptHandler = getPromptHandler;

// Synthesize the handlers, making sure a ToolsCapability is specified.
listPromptsHandler = async (request, cancellationToken) =>
{
ListPromptsResult result = new();
foreach (McpServerPrompt prompt in prompts)
{
result.Prompts.Add(prompt.ProtocolPrompt);
}
ListPromptsResult result = originalListPromptsHandler is not null ?
await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) :
new();

if (originalListPromptsHandler is not null)
if (request.Params?.Cursor is null)
{
string? nextCursor = null;
do
{
ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false);
result.Prompts.AddRange(extraResults.Prompts);

nextCursor = extraResults.NextCursor;
if (nextCursor is not null)
{
request = request with { Params = new() { Cursor = nextCursor } };
}
}
while (nextCursor is not null);
result.Prompts.AddRange(prompts.Select(t => t.ProtocolPrompt));
}

return result;
};

var originalGetPromptHandler = getPromptHandler;
getPromptHandler = (request, cancellationToken) =>
{
if (request.Params is null ||
Expand Down Expand Up @@ -316,38 +300,23 @@ private void SetToolsHandler(McpServerOptions options)
// Handle tools provided via DI.
if (tools is { IsEmpty: false })
{
var originalListToolsHandler = listToolsHandler;
var originalCallToolHandler = callToolHandler;

// Synthesize the handlers, making sure a ToolsCapability is specified.
var originalListToolsHandler = listToolsHandler;
listToolsHandler = async (request, cancellationToken) =>
{
ListToolsResult result = new();
foreach (McpServerTool tool in tools)
{
result.Tools.Add(tool.ProtocolTool);
}
ListToolsResult result = originalListToolsHandler is not null ?
await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) :
new();

if (originalListToolsHandler is not null)
if (request.Params?.Cursor is null)
{
string? nextCursor = null;
do
{
ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false);
result.Tools.AddRange(extraResults.Tools);

nextCursor = extraResults.NextCursor;
if (nextCursor is not null)
{
request = request with { Params = new() { Cursor = nextCursor } };
}
}
while (nextCursor is not null);
result.Tools.AddRange(tools.Select(t => t.ProtocolTool));
}

return result;
};

var originalCallToolHandler = callToolHandler;
callToolHandler = (request, cancellationToken) =>
{
if (request.Params is null ||
Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/TokenProgress.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using ModelContextProtocol.Shared;

namespace ModelContextProtocol;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Transport;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Transport;
using ModelContextProtocol.Tests.Utils;
using System.ComponentModel;
using System.IO.Pipelines;
using System.Threading.Channels;

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

namespace ModelContextProtocol.Tests.Configuration;

public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable
Expand All @@ -28,7 +29,70 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper
{
ServiceCollection sc = new();
sc.AddSingleton(LoggerFactory);
_builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts<SimplePrompts>();
_builder = sc
.AddMcpServer()
.WithStdioServerTransport()
.WithListPromptsHandler(async (request, cancellationToken) =>
{
var cursor = request.Params?.Cursor;
switch (cursor)
{
case null:
return new()
{
NextCursor = "abc",
Prompts = [new()
{
Name = "FirstCustomPrompt",
Description = "First prompt returned by custom handler",
}],
};

case "abc":
return new()
{
NextCursor = "def",
Prompts = [new()
{
Name = "SecondCustomPrompt",
Description = "Second prompt returned by custom handler",
}],
};

case "def":
return new()
{
NextCursor = null,
Prompts = [new()
{
Name = "FinalCustomPrompt",
Description = "Final prompt returned by custom handler",
}],
};

default:
throw new Exception("Unexpected cursor");
}
})
.WithGetPromptHandler(async (request, cancellationToken) =>
{
switch (request.Params?.Name)
{
case "FirstCustomPrompt":
case "SecondCustomPrompt":
case "FinalCustomPrompt":
return new GetPromptResult()
{
Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }],
};

default:
throw new Exception($"Unknown prompt '{request.Params?.Name}'");
}
})
.WithPrompts<SimplePrompts>();


// Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport.
sc.AddSingleton<ITransport>(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory));
sc.AddSingleton(new ObjectWithId());
Expand Down Expand Up @@ -85,7 +149,7 @@ public async Task Can_List_And_Call_Registered_Prompts()
IMcpClient client = await CreateMcpClientForServer();

var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(3, prompts.Count);
Assert.Equal(6, prompts.Count);

var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages));
Assert.Equal("Returns chat messages", prompt.Description);
Expand All @@ -98,6 +162,14 @@ public async Task Can_List_And_Call_Registered_Prompts()
Assert.Equal(2, chatMessages.Count);
Assert.Equal("The prompt is: hello", chatMessages[0].Text);
Assert.Equal("Summarize.", chatMessages[1].Text);

prompt = prompts.First(t => t.Name == "SecondCustomPrompt");
Assert.Equal("Second prompt returned by custom handler", prompt.Description);
result = await prompt.GetAsync(cancellationToken: TestContext.Current.CancellationToken);
chatMessages = result.ToChatMessages();
Assert.NotNull(chatMessages);
Assert.Single(chatMessages);
Assert.Equal("hello from SecondCustomPrompt", chatMessages[0].Text);
}

[Fact]
Expand All @@ -106,7 +178,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
IMcpClient client = await CreateMcpClientForServer();

var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(3, prompts.Count);
Assert.Equal(6, prompts.Count);

Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
client.AddNotificationHandler("notifications/prompts/list_changed", notification =>
Expand All @@ -127,7 +199,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
await notificationRead;

prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(4, prompts.Count);
Assert.Equal(7, prompts.Count);
Assert.Contains(prompts, t => t.Name == "NewPrompt");

notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
Expand All @@ -136,7 +208,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
await notificationRead;

prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(3, prompts.Count);
Assert.Equal(6, prompts.Count);
Assert.DoesNotContain(prompts, t => t.Name == "NewPrompt");
}

Expand Down
Loading