Skip to content

Commit b12d728

Browse files
authored
Fix pagination handling in McpServer (#177)
- We were adding tools/prompts from the collections on every request. If multiple requests came in with different cursors, we'd re-add the same tools each time. - We were defeating the purpose of pagination by doing all of the aggregation in the server. If a custom handler returns a paginated result, we should instead propagate that back to the client, who can choose to get more results when needed.
1 parent 8dc1f5d commit b12d728

File tree

8 files changed

+193
-66
lines changed

8 files changed

+193
-66
lines changed

src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using Microsoft.Extensions.AI;
22
using Microsoft.Extensions.DependencyInjection;
33
using ModelContextProtocol.Protocol.Types;
4-
using ModelContextProtocol.Shared;
54
using ModelContextProtocol.Utils;
65
using ModelContextProtocol.Utils.Json;
76
using System.Diagnostics.CodeAnalysis;

src/ModelContextProtocol/Server/McpServer.cs

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.Extensions.Logging;
2-
using ModelContextProtocol.Logging;
32
using ModelContextProtocol.Protocol.Messages;
43
using ModelContextProtocol.Protocol.Transport;
54
using ModelContextProtocol.Protocol.Types;
@@ -214,41 +213,26 @@ private void SetPromptsHandler(McpServerOptions options)
214213
throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together.");
215214
}
216215

217-
// Handle tools provided via DI.
216+
// Handle prompts provided via DI.
218217
if (prompts is { IsEmpty: false })
219218
{
219+
// Synthesize the handlers, making sure a PromptsCapability is specified.
220220
var originalListPromptsHandler = listPromptsHandler;
221-
var originalGetPromptHandler = getPromptHandler;
222-
223-
// Synthesize the handlers, making sure a ToolsCapability is specified.
224221
listPromptsHandler = async (request, cancellationToken) =>
225222
{
226-
ListPromptsResult result = new();
227-
foreach (McpServerPrompt prompt in prompts)
228-
{
229-
result.Prompts.Add(prompt.ProtocolPrompt);
230-
}
223+
ListPromptsResult result = originalListPromptsHandler is not null ?
224+
await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) :
225+
new();
231226

232-
if (originalListPromptsHandler is not null)
227+
if (request.Params?.Cursor is null)
233228
{
234-
string? nextCursor = null;
235-
do
236-
{
237-
ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false);
238-
result.Prompts.AddRange(extraResults.Prompts);
239-
240-
nextCursor = extraResults.NextCursor;
241-
if (nextCursor is not null)
242-
{
243-
request = request with { Params = new() { Cursor = nextCursor } };
244-
}
245-
}
246-
while (nextCursor is not null);
229+
result.Prompts.AddRange(prompts.Select(t => t.ProtocolPrompt));
247230
}
248231

249232
return result;
250233
};
251234

235+
var originalGetPromptHandler = getPromptHandler;
252236
getPromptHandler = (request, cancellationToken) =>
253237
{
254238
if (request.Params is null ||
@@ -316,38 +300,23 @@ private void SetToolsHandler(McpServerOptions options)
316300
// Handle tools provided via DI.
317301
if (tools is { IsEmpty: false })
318302
{
319-
var originalListToolsHandler = listToolsHandler;
320-
var originalCallToolHandler = callToolHandler;
321-
322303
// Synthesize the handlers, making sure a ToolsCapability is specified.
304+
var originalListToolsHandler = listToolsHandler;
323305
listToolsHandler = async (request, cancellationToken) =>
324306
{
325-
ListToolsResult result = new();
326-
foreach (McpServerTool tool in tools)
327-
{
328-
result.Tools.Add(tool.ProtocolTool);
329-
}
307+
ListToolsResult result = originalListToolsHandler is not null ?
308+
await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) :
309+
new();
330310

331-
if (originalListToolsHandler is not null)
311+
if (request.Params?.Cursor is null)
332312
{
333-
string? nextCursor = null;
334-
do
335-
{
336-
ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false);
337-
result.Tools.AddRange(extraResults.Tools);
338-
339-
nextCursor = extraResults.NextCursor;
340-
if (nextCursor is not null)
341-
{
342-
request = request with { Params = new() { Cursor = nextCursor } };
343-
}
344-
}
345-
while (nextCursor is not null);
313+
result.Tools.AddRange(tools.Select(t => t.ProtocolTool));
346314
}
347315

348316
return result;
349317
};
350318

319+
var originalCallToolHandler = callToolHandler;
351320
callToolHandler = (request, cancellationToken) =>
352321
{
353322
if (request.Params is null ||

src/ModelContextProtocol/TokenProgress.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using ModelContextProtocol.Protocol.Messages;
22
using ModelContextProtocol.Server;
3-
using ModelContextProtocol.Shared;
43

54
namespace ModelContextProtocol;
65

tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using ModelContextProtocol.Client;
33
using ModelContextProtocol.Protocol.Transport;
44
using ModelContextProtocol.Server;
5-
using ModelContextProtocol.Tests.Transport;
65
using ModelContextProtocol.Tests.Utils;
76
using System.IO.Pipelines;
87

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
using ModelContextProtocol.Protocol.Transport;
77
using ModelContextProtocol.Protocol.Types;
88
using ModelContextProtocol.Server;
9-
using ModelContextProtocol.Tests.Transport;
109
using ModelContextProtocol.Tests.Utils;
1110
using System.ComponentModel;
1211
using System.IO.Pipelines;
1312
using System.Threading.Channels;
1413

14+
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
15+
1516
namespace ModelContextProtocol.Tests.Configuration;
1617

1718
public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable
@@ -28,7 +29,70 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper
2829
{
2930
ServiceCollection sc = new();
3031
sc.AddSingleton(LoggerFactory);
31-
_builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts<SimplePrompts>();
32+
_builder = sc
33+
.AddMcpServer()
34+
.WithStdioServerTransport()
35+
.WithListPromptsHandler(async (request, cancellationToken) =>
36+
{
37+
var cursor = request.Params?.Cursor;
38+
switch (cursor)
39+
{
40+
case null:
41+
return new()
42+
{
43+
NextCursor = "abc",
44+
Prompts = [new()
45+
{
46+
Name = "FirstCustomPrompt",
47+
Description = "First prompt returned by custom handler",
48+
}],
49+
};
50+
51+
case "abc":
52+
return new()
53+
{
54+
NextCursor = "def",
55+
Prompts = [new()
56+
{
57+
Name = "SecondCustomPrompt",
58+
Description = "Second prompt returned by custom handler",
59+
}],
60+
};
61+
62+
case "def":
63+
return new()
64+
{
65+
NextCursor = null,
66+
Prompts = [new()
67+
{
68+
Name = "FinalCustomPrompt",
69+
Description = "Final prompt returned by custom handler",
70+
}],
71+
};
72+
73+
default:
74+
throw new Exception("Unexpected cursor");
75+
}
76+
})
77+
.WithGetPromptHandler(async (request, cancellationToken) =>
78+
{
79+
switch (request.Params?.Name)
80+
{
81+
case "FirstCustomPrompt":
82+
case "SecondCustomPrompt":
83+
case "FinalCustomPrompt":
84+
return new GetPromptResult()
85+
{
86+
Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }],
87+
};
88+
89+
default:
90+
throw new Exception($"Unknown prompt '{request.Params?.Name}'");
91+
}
92+
})
93+
.WithPrompts<SimplePrompts>();
94+
95+
3296
// Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport.
3397
sc.AddSingleton<ITransport>(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory));
3498
sc.AddSingleton(new ObjectWithId());
@@ -85,7 +149,7 @@ public async Task Can_List_And_Call_Registered_Prompts()
85149
IMcpClient client = await CreateMcpClientForServer();
86150

87151
var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
88-
Assert.Equal(3, prompts.Count);
152+
Assert.Equal(6, prompts.Count);
89153

90154
var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages));
91155
Assert.Equal("Returns chat messages", prompt.Description);
@@ -98,6 +162,14 @@ public async Task Can_List_And_Call_Registered_Prompts()
98162
Assert.Equal(2, chatMessages.Count);
99163
Assert.Equal("The prompt is: hello", chatMessages[0].Text);
100164
Assert.Equal("Summarize.", chatMessages[1].Text);
165+
166+
prompt = prompts.First(t => t.Name == "SecondCustomPrompt");
167+
Assert.Equal("Second prompt returned by custom handler", prompt.Description);
168+
result = await prompt.GetAsync(cancellationToken: TestContext.Current.CancellationToken);
169+
chatMessages = result.ToChatMessages();
170+
Assert.NotNull(chatMessages);
171+
Assert.Single(chatMessages);
172+
Assert.Equal("hello from SecondCustomPrompt", chatMessages[0].Text);
101173
}
102174

103175
[Fact]
@@ -106,7 +178,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
106178
IMcpClient client = await CreateMcpClientForServer();
107179

108180
var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
109-
Assert.Equal(3, prompts.Count);
181+
Assert.Equal(6, prompts.Count);
110182

111183
Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
112184
client.AddNotificationHandler("notifications/prompts/list_changed", notification =>
@@ -127,7 +199,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
127199
await notificationRead;
128200

129201
prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
130-
Assert.Equal(4, prompts.Count);
202+
Assert.Equal(7, prompts.Count);
131203
Assert.Contains(prompts, t => t.Name == "NewPrompt");
132204

133205
notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
@@ -136,7 +208,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes()
136208
await notificationRead;
137209

138210
prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
139-
Assert.Equal(3, prompts.Count);
211+
Assert.Equal(6, prompts.Count);
140212
Assert.DoesNotContain(prompts, t => t.Name == "NewPrompt");
141213
}
142214

0 commit comments

Comments
 (0)