Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Cellm/Cellm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
<PackageReference Include="Microsoft.Extensions.Options" Version="9.0.5" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.5" />
<PackageReference Include="Mistral.SDK" Version="2.2.0" />
<PackageReference Include="ModelContextProtocol" Version="0.1.0-preview.11" />
<PackageReference Include="ModelContextProtocol" Version="0.2.0-preview.3" />
<PackageReference Include="OllamaSharp" Version="5.2.2" />
<PackageReference Include="PdfPig" Version="0.1.10" />
<PackageReference Include="Sentry.Extensions.Logging" Version="5.5.1" />
Expand Down
6 changes: 3 additions & 3 deletions src/Cellm/Models/Behaviors/CacheBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TRe
// Tools are explicitly [JsonIgnore]'d, but we want to send prompt if user added/removed tools
var toolsAsJson = JsonSerializer.Serialize(request.Prompt.Options.Tools);

byte[] hashBytes = SHA256.HashData(Encoding.UTF8.GetBytes(promptAsJson + toolsAsJson));
var key = Convert.ToBase64String(hashBytes);
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(promptAsJson + toolsAsJson));
var key = Convert.ToBase64String(hash);

return await cache.GetOrCreateAsync(
key,
async innerCancellationToken => await next(),
async innerCancellationToken => await next().ConfigureAwait(false),
options: _cacheEntryOptions,
Tags,
cancellationToken
Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Models/Behaviors/ProviderBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TRe
providerBehavior.Before(request.Prompt);
}

var response = await next();
var response = await next().ConfigureAwait(false);

foreach (var providerBehavior in enabledProviderBehaviors)
{
Expand Down
175 changes: 68 additions & 107 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,152 +9,113 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;

namespace Cellm.Models.Behaviors;

internal class ToolBehavior<TRequest, TResponse>(
Account account,
IOptionsMonitor<CellmAddInConfiguration> providerConfiguration,
IOptionsMonitor<ModelContextProtocolConfiguration> modelContextProtocolConfiguration,
IEnumerable<AIFunction> functions,
ILogger<ToolBehavior<TRequest, TResponse>> logger,
ILoggerFactory loggerFactory)
: IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt
Account account,
IOptionsMonitor<CellmAddInConfiguration> cellmAddInConfiguration,
IOptionsMonitor<ModelContextProtocolConfiguration> modelContextProtocolConfiguration,
IEnumerable<AIFunction> functions,
ILogger<ToolBehavior<TRequest, TResponse>> logger,
ILoggerFactory loggerFactory)
: IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt
{
// TODO: Use HybridCache
private readonly ConcurrentDictionary<string, IMcpClient> _mcpClientCache = [];
private readonly ConcurrentDictionary<string, IList<McpClientTool>> _mcpClientToolCache = [];
private readonly SemaphoreSlim _asyncLock = new(1, 1);
// TODO: Cannot use HybridCache because McpClientTool instances can be serialized
private readonly ConcurrentDictionary<string, IList<McpClientTool>> _cache = new();

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (providerConfiguration.CurrentValue.EnableTools.Any(t => t.Value))
if (cellmAddInConfiguration.CurrentValue.EnableTools.Any(t => t.Value))
{
logger.LogDebug("Native tools enabled");

request.Prompt.Options.Tools = GetNativeTools();
}
else
{
logger.LogDebug("Native tools disabled");
request.Prompt.Options.Tools = [.. functions.Where(f => cellmAddInConfiguration.CurrentValue.EnableTools[f.Name])];
}

var enableModelContextProtocol = await account.HasEntitlementAsync(Entitlement.EnableModelContextProtocol);
var enableModelContextProtocol = await account.HasEntitlementAsync(Entitlement.EnableModelContextProtocol).ConfigureAwait(false);

if (providerConfiguration.CurrentValue.EnableModelContextProtocolServers.Any(t => t.Value) && enableModelContextProtocol)
if (cellmAddInConfiguration.CurrentValue.EnableModelContextProtocolServers.Any(t => t.Value) && enableModelContextProtocol)
{
logger.LogDebug("MCP tools enabled");

request.Prompt.Options.Tools ??= [];

await foreach (var tool in GetModelContextProtocolToolsAsync(cancellationToken))
await foreach (var tool in GetMcpToolsAsync(cancellationToken))
{
request.Prompt.Options.Tools ??= [];
request.Prompt.Options.Tools.Add(tool);
}
}
else
{
logger.LogDebug("MCP tools disabled");
}

if (request.Prompt.Options.Tools?.Any() ?? false)
{
logger.LogDebug("Tools: {tools}", request.Prompt.Options.Tools);
logger.LogDebug("Tools enabled: {tools}", request.Prompt.Options.Tools);
}
else
{
logger.LogDebug("Tools disabled");
}

return await next().ConfigureAwait(false);
}

private List<AITool> GetNativeTools()
private async IAsyncEnumerable<McpClientTool> GetMcpToolsAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
return [.. functions.Where(f => providerConfiguration.CurrentValue.EnableTools[f.Name])];
}
var stdioToolTasks = modelContextProtocolConfiguration.CurrentValue.StdioServers
.Where(stdioClientTransportOptions => cellmAddInConfiguration.CurrentValue.EnableModelContextProtocolServers
.TryGetValue(stdioClientTransportOptions.Name ?? throw new NullReferenceException(nameof(stdioClientTransportOptions.Name)), out var isEnabled) && isEnabled)
.Select(stdioClientTransportOptions => GetOrFetchServerToolsAsync(stdioClientTransportOptions, cancellationToken))
.ToList();

// TODO: Query servers in parallel
private async IAsyncEnumerable<AITool> GetModelContextProtocolToolsAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
foreach (var serverConfiguration in modelContextProtocolConfiguration.CurrentValue.StdioServers)
{
var serverName = serverConfiguration.Name ?? throw new NullReferenceException(nameof(serverConfiguration.Name));
var sseToolTasks = modelContextProtocolConfiguration.CurrentValue.SseServers
.Where(sseClientTransportOptions => cellmAddInConfiguration.CurrentValue.EnableModelContextProtocolServers
.TryGetValue(sseClientTransportOptions.Name ?? throw new NullReferenceException(nameof(sseClientTransportOptions.Name)), out var isEnabled) && isEnabled)
.Select(sseClientTransportOptions => GetOrFetchServerToolsAsync(sseClientTransportOptions, cancellationToken))
.ToList();

if (!providerConfiguration.CurrentValue.EnableModelContextProtocolServers.TryGetValue(serverName, out var isEnabled) || !isEnabled)
{
continue;
}
List<Task<IList<McpClientTool>>> pendingTasks = [.. stdioToolTasks, .. sseToolTasks];

_mcpClientToolCache.TryGetValue(serverName, out var serverTools);

if (serverTools is null)
{
await _asyncLock.WaitAsync(cancellationToken).ConfigureAwait(false);

try
{
_mcpClientCache.TryGetValue(serverName, out var StdioMcpClient);

if (StdioMcpClient is null)
{
var clientTransport = new StdioClientTransport(serverConfiguration);
StdioMcpClient = await McpClientFactory.CreateAsync(clientTransport, loggerFactory: loggerFactory, cancellationToken: cancellationToken).ConfigureAwait(false);
_mcpClientCache[serverName] = StdioMcpClient;
}

serverTools = await StdioMcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
_mcpClientToolCache[serverName] = serverTools;
}
finally
{
_asyncLock.Release();
}
}
while (pendingTasks.Count > 0)
{
var completedTask = await Task.WhenAny(pendingTasks).ConfigureAwait(false);
pendingTasks.Remove(completedTask);

foreach (var serverTool in serverTools)
foreach (var tool in await completedTask)
{
yield return serverTool;
cancellationToken.ThrowIfCancellationRequested();
yield return tool;
}
}
}

foreach (var serverConfiguration in modelContextProtocolConfiguration.CurrentValue.SseServers)
private async Task<IList<McpClientTool>> GetOrFetchServerToolsAsync(StdioClientTransportOptions stdioClientTransportOptions, CancellationToken cancellationToken)
{
if (_cache.ContainsKey(stdioClientTransportOptions.Name ?? throw new NullReferenceException(nameof(stdioClientTransportOptions))) && _cache[stdioClientTransportOptions.Name] is IList<McpClientTool> cachedTools)
{
var serverName = serverConfiguration.Name ?? throw new NullReferenceException(nameof(serverConfiguration.Name));
logger.LogDebug("Using cached tools for {ServerName}", stdioClientTransportOptions.Name);
return cachedTools;
}

if (!providerConfiguration.CurrentValue.EnableModelContextProtocolServers.TryGetValue(serverName, out var isEnabled) || !isEnabled)
{
continue;
}
var clientTransport = new StdioClientTransport(stdioClientTransportOptions);
var mcpClient = await McpClientFactory.CreateAsync(clientTransport, loggerFactory: loggerFactory, cancellationToken: cancellationToken).ConfigureAwait(false);
var tools = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);

_mcpClientToolCache.TryGetValue(serverName, out var serverTools);
_cache[stdioClientTransportOptions.Name ?? throw new NullReferenceException(nameof(stdioClientTransportOptions))] = tools;

if (serverTools is null)
{
await _asyncLock.WaitAsync(cancellationToken).ConfigureAwait(false);

try
{
_mcpClientCache.TryGetValue(serverName, out var SseMcpClient);

if (SseMcpClient is null)
{
var clientTransport = new SseClientTransport(serverConfiguration);
SseMcpClient = await McpClientFactory.CreateAsync(clientTransport, loggerFactory: loggerFactory, cancellationToken: cancellationToken).ConfigureAwait(false);
_mcpClientCache[serverName] = SseMcpClient;
}

serverTools = await SseMcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
_mcpClientToolCache[serverName] = serverTools;
}
finally
{
_asyncLock.Release();
}
}
return tools;
}

foreach (var serverTool in serverTools)
{
yield return serverTool;
}
private async Task<IList<McpClientTool>> GetOrFetchServerToolsAsync(SseClientTransportOptions sseClientTransportOptions, CancellationToken cancellationToken)
{
if (_cache.ContainsKey(sseClientTransportOptions.Name ?? throw new NullReferenceException(nameof(sseClientTransportOptions))) && _cache[sseClientTransportOptions.Name] is IList<McpClientTool> cachedTools)
{
logger.LogDebug("Using cached tools for {ServerName}", sseClientTransportOptions.Name);
return cachedTools;
}

var clientTransport = new SseClientTransport(sseClientTransportOptions);
var mcpClient = await McpClientFactory.CreateAsync(clientTransport, loggerFactory: loggerFactory, cancellationToken: cancellationToken).ConfigureAwait(false);
var tools = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);

_cache[sseClientTransportOptions.Name ?? throw new NullReferenceException(nameof(sseClientTransportOptions))] = tools;

return tools;
}
}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Client;

namespace Cellm.Tools.ModelContextProtocol;

Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"StdioServers": [
{
"Command": "npx",
"Arguments": [ "-y", "@playwright/mcp@latest" ],
"Arguments": [ "-y", "@playwright/mcp@latest", "--isolated", "--headless", "--image-responses", "omit", "--caps", "wait,history,pdf" ],
"Name": "Playwright"
}
],
Expand Down
27 changes: 17 additions & 10 deletions src/Cellm/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,12 @@
},
"ModelContextProtocol": {
"type": "Direct",
"requested": "[0.1.0-preview.11, )",
"resolved": "0.1.0-preview.11",
"contentHash": "TWSeMx0mcdX4pZVB0gS2ii06Ll9pz+PAG8OieTI0MG9rD4YmFBrCMKoRXf381I/S0jVjktJeFHdrJVKeSiH7BA==",
"requested": "[0.2.0-preview.3, )",
"resolved": "0.2.0-preview.3",
"contentHash": "logBVqTpZuYc3Exkx+9AmjHENVYT7+LkeTbXBDK5ghiUSDpGuCoxndeJOHRj0c9K9JqFZJweNXqlu4hihK2CKQ==",
"dependencies": {
"Microsoft.Extensions.AI": "9.4.0-preview.1.25207.5",
"Microsoft.Extensions.AI.Abstractions": "9.4.0-preview.1.25207.5",
"Microsoft.Extensions.Hosting.Abstractions": "9.0.4",
"Microsoft.Extensions.Logging.Abstractions": "9.0.4",
"System.Net.ServerSentEvents": "10.0.0-preview.2.25163.2"
"Microsoft.Extensions.Hosting.Abstractions": "9.0.5",
"ModelContextProtocol.Core": "0.2.0-preview.3"
}
},
"OllamaSharp": {
Expand Down Expand Up @@ -609,6 +606,16 @@
"Microsoft.NETCore.Platforms": "5.0.0"
}
},
"ModelContextProtocol.Core": {
"type": "Transitive",
"resolved": "0.2.0-preview.3",
"contentHash": "90zVR7ZcGN6cyDZM2EjV48n/H17AdA5fYt3jjTL0PY275BquFhESyRAcorAAG3BCqV/vXo6HuQpFUuPam/9xOw==",
"dependencies": {
"Microsoft.Extensions.AI.Abstractions": "9.5.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.5",
"System.Net.ServerSentEvents": "10.0.0-preview.4.25258.110"
}
},
"OpenAI": {
"type": "Transitive",
"resolved": "2.2.0-beta.4",
Expand Down Expand Up @@ -681,8 +688,8 @@
},
"System.Net.ServerSentEvents": {
"type": "Transitive",
"resolved": "10.0.0-preview.2.25163.2",
"contentHash": "XHyvtQSgco0Sv0kz9yNBv93k3QOoAVzIVd5XbQoTqjV9sqkzWHsToNknyxtNjcXQwb+O9TfzSlNobsBWwnKD3Q=="
"resolved": "10.0.0-preview.4.25258.110",
"contentHash": "5SG8yFN0e6y3VVaCMWSRYdDW15Apzg2LQoYUtl+jFrGb/+RLSMYmIdU9hO1hbHqpU0JD/lfBlCox0FebdAGGOQ=="
},
"System.Numerics.Vectors": {
"type": "Transitive",
Expand Down
Loading