Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions samples/AspNetCoreSseServer/Attributes/LimitCalls.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using ModelContextProtocol.Core;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;

namespace AspNetCoreSseServer.Attributes;

public class LimitCallsAttribute(int maxCalls) : ToolFilterAttribute
{
private int _callCount;

public override ValueTask<CallToolResult>? OnToolCalling(Tool tool, RequestContext<CallToolRequestParams> context)
{
//Thread-safe increment
var currentCount = Interlocked.Add(ref _callCount, 1);

//Log count
Console.Out.WriteLine($"Tool: {tool.Name} called {currentCount} time(s)");

//If under threshold, do nothing
if (currentCount <= maxCalls)
return null; //do nothing

//If above threshold, return error message
return new ValueTask<CallToolResult>(new CallToolResult
{
Content = [new TextContentBlock { Text = $"This tool can only be called {maxCalls} time(s)" }]
});
}

public override bool OnToolListed(Tool tool, RequestContext<ListToolsRequestParams> context)
{
//With the provided request context, you can access the dependency injection
var configuration = context.Services?.GetService<IConfiguration>();
var hide = configuration?["hide-tools-above-limit"] == "True";

//Prevent the tool being listed (return false)
//if the hide flag is true and the call count is above the threshold
return _callCount <= maxCalls || !hide;
}
}
2 changes: 2 additions & 0 deletions samples/AspNetCoreSseServer/Tools/EchoTool.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
using ModelContextProtocol.Server;
using System.ComponentModel;
using AspNetCoreSseServer.Attributes;

namespace TestServerWithHosting.Tools;

[McpServerToolType]
public sealed class EchoTool
{
[McpServerTool, Description("Echoes the input back to the client.")]
[LimitCalls(maxCalls: 10)]
public static string Echo(string message)
{
return "hello " + message;
Expand Down
3 changes: 2 additions & 1 deletion samples/AspNetCoreSseServer/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
"Microsoft.AspNetCore": "Warning"
}
},
"AllowedHosts": "*"
"AllowedHosts": "*",
"hide-tools-above-limit": true
}
12 changes: 10 additions & 2 deletions src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;
using ModelContextProtocol.Core;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -146,7 +147,7 @@ options.OpenWorld is not null ||
}
}

return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping);
return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Filters ?? []);
}

private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options)
Expand Down Expand Up @@ -185,6 +186,9 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe
{
newOptions.Description ??= descAttr.Description;
}

var filters = method.GetCustomAttributes<ToolFilterAttribute>().OrderBy(f => f.Order).ToArray<IToolFilter>();
newOptions.Filters = filters;

return newOptions;
}
Expand All @@ -193,17 +197,21 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe
internal AIFunction AIFunction { get; }

/// <summary>Initializes a new instance of the <see cref="McpServerTool"/> class.</summary>
private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping)
private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping, IToolFilter[] filters)
{
AIFunction = function;
ProtocolTool = tool;
Filters = filters;
_logger = serviceProvider?.GetService<ILoggerFactory>()?.CreateLogger<McpServerTool>() ?? (ILogger)NullLogger.Instance;
_structuredOutputRequiresWrapping = structuredOutputRequiresWrapping;
}

/// <inheritdoc />
public override Tool ProtocolTool { get; }

/// <inheritdoc />
public override IToolFilter[] Filters { get; }

/// <inheritdoc />
public override async ValueTask<CallToolResult> InvokeAsync(
RequestContext<CallToolRequestParams> request, CancellationToken cancellationToken = default)
Expand Down
22 changes: 21 additions & 1 deletion src/ModelContextProtocol.Core/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)
{
foreach (var t in tools)
{
if (t.Filters.Any(f => !f.OnToolListed(t.ProtocolTool,request)))
{
continue;
}

result.Tools.Add(t.ProtocolTool);
}
}
Expand All @@ -461,7 +466,22 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)
if (request.Params is not null &&
tools.TryGetPrimitive(request.Params.Name, out var tool))
{
return tool.InvokeAsync(request, cancellationToken);
foreach (var filter in tool.Filters)
{
var filterResult = filter.OnToolCalling(tool.ProtocolTool, request);
if(filterResult != null)
return filterResult.Value;
}

var result = tool.InvokeAsync(request, cancellationToken);

foreach (var filter in tool.Filters)
{
var filterResult = filter.OnToolCalled(tool.ProtocolTool, request, result);
if(filterResult != null)
return filterResult.Value;
}
return result;
}

return originalCallToolHandler(request, cancellationToken);
Expand Down
4 changes: 4 additions & 0 deletions src/ModelContextProtocol.Core/Server/McpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using ModelContextProtocol.Protocol;
using System.Reflection;
using System.Text.Json;
using ModelContextProtocol.Core;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -140,6 +141,9 @@ protected McpServerTool()

/// <summary>Gets the protocol <see cref="Tool"/> type for this instance.</summary>
public abstract Tool ProtocolTool { get; }

/// <summary>Gets the filters (<see cref="IToolFilter"/>) associated to this tool.</summary>
public abstract IToolFilter[] Filters { get; }

/// <summary>Invokes the <see cref="McpServerTool"/>.</summary>
/// <param name="request">The request information resulting in the invocation of this tool.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using ModelContextProtocol.Protocol;
using System.ComponentModel;
using System.Text.Json;
using ModelContextProtocol.Core;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -155,6 +156,9 @@ public sealed class McpServerToolCreateOptions
/// </remarks>
public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; }

/// TODO
public IToolFilter[] Filters { get; set; } = [];

/// <summary>
/// Creates a shallow clone of the current <see cref="McpServerToolCreateOptions"/> instance.
/// </summary>
Expand Down
39 changes: 39 additions & 0 deletions src/ModelContextProtocol.Core/ToolFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;

namespace ModelContextProtocol.Core;

/// TODO:
public interface IToolFilter
{
/// TODO:
bool OnToolListed(Tool tool, RequestContext<ListToolsRequestParams> context);

/// TODO:
ValueTask<CallToolResult>? OnToolCalling(Tool tool, RequestContext<CallToolRequestParams> context);

/// TODO:
ValueTask<CallToolResult>? OnToolCalled(Tool tool, RequestContext<CallToolRequestParams> context, ValueTask<CallToolResult> callResult);
}

/// TODO:
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public abstract class ToolFilterAttribute(int order = 0) : Attribute, IToolFilter
{
/// <summary>
/// Gets the order value for determining the order of execution of filters. Filters execute in
/// ascending numeric value of the <see cref="Order"/> property.
/// </summary>
public int Order { get; } = order;

/// <inheritdoc />
public virtual bool OnToolListed(Tool tool, RequestContext<ListToolsRequestParams> context) => true;

/// <inheritdoc />
public virtual ValueTask<CallToolResult>? OnToolCalling(Tool tool, RequestContext<CallToolRequestParams> context) =>
null;

/// <inheritdoc />
public virtual ValueTask<CallToolResult>? OnToolCalled(Tool tool, RequestContext<CallToolRequestParams> context,
ValueTask<CallToolResult> callResult) => null;
}