Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions src/Cellm.Models/Behaviors/ToolBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,73 @@
using Cellm.Models.Providers;
using System.Runtime.CompilerServices;
using Cellm.Models.Providers;
using Cellm.Tools.ModelContextProtocol;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;

namespace Cellm.Models.Tools;

internal class ToolBehavior<TRequest, TResponse>(IOptionsMonitor<ProviderConfiguration> providerConfiguration, IEnumerable<AIFunction> functions)
internal class ToolBehavior<TRequest, TResponse>(
IOptionsMonitor<ProviderConfiguration> providerConfiguration,
IOptionsMonitor<ModelContextProtocolConfiguration> modelContextProtocolConfiguration,
IEnumerable<AIFunction> functions)
: IPipelineBehavior<TRequest, TResponse> where TRequest : IModelRequest<TResponse>
{
public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (providerConfiguration.CurrentValue.EnableTools.Any(t => t.Value))
{
request.Prompt.Options.Tools = functions
.Where(f => providerConfiguration.CurrentValue.EnableTools[f.Name])
.ToList<AITool>();
request.Prompt.Options.Tools = GetNativeTools();
}

if (providerConfiguration.CurrentValue.EnableModelContextProtocolServers.Any(t => t.Value))
{
request.Prompt.Options.Tools ??= [];

await foreach (var tool in GetModelContextProtocolTools(cancellationToken))
{
request.Prompt.Options.Tools.Add(tool);
}
}

return await next();
}

private List<AITool> GetNativeTools()
{
return functions
.Where(f => providerConfiguration.CurrentValue.EnableTools[f.Name])
.ToList<AITool>();
}

// TODO:
// - Cache capabilities on a per-server basis.
// - Query servers in parallel.
//
// Note: We cannot get list of tools only on startup because user can add/delete/enable/disable servers. But
// with this solution we query servers for capabilities on every model call which is hardly ideal.
// We need to cache on a per-server basis so servers added at runtime will be queried
private async IAsyncEnumerable<AITool> GetModelContextProtocolTools([EnumeratorCancellation] CancellationToken cancellationToken)
{
foreach (var server in modelContextProtocolConfiguration.CurrentValue.Servers)
{
if (!providerConfiguration.CurrentValue.EnableModelContextProtocolServers.TryGetValue(server.Name, out var isEnabled) || !isEnabled)
{
continue;
}

var client = await McpClientFactory.CreateAsync(server, cancellationToken: cancellationToken);

if (client is null)
{
continue;
}

foreach (var tool in await client.ListToolsAsync(cancellationToken: cancellationToken))
{
yield return tool;
}
}
}
}
8 changes: 7 additions & 1 deletion src/Cellm.Models/Providers/ProviderConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace Cellm.Models.Providers;
using ModelContextProtocol;

namespace Cellm.Models.Providers;

public class ProviderConfiguration : IProviderConfiguration
{
Expand All @@ -12,6 +14,10 @@ public class ProviderConfiguration : IProviderConfiguration

public Dictionary<string, bool> EnableTools { get; init; } = [];

public Dictionary<string, bool> EnableModelContextProtocolServers { get; init; } = [];

public Dictionary<string, McpServerConfig> ModelContextProtocolServers { get; init; } = [];

public bool EnableCache { get; init; } = true;

public int CacheTimeoutInSeconds { get; init; } = 3600;
Expand Down
4 changes: 2 additions & 2 deletions src/Cellm/AddIn/ExcelRibbon.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ public bool OnGetCachePressed(IRibbonControl control)

public void OnFileSearchToggled(IRibbonControl control, bool pressed)
{
SetValue("ProviderConfiguration:EnableTools:GlobRequest", pressed.ToString());
SetValue("ProviderConfiguration:EnableTools:FileSearchRequest", pressed.ToString());
}

public bool OnGetFileSearchPressed(IRibbonControl control)
{
var value = GetValue("ProviderConfiguration:EnableTools:GlobRequest");
var value = GetValue("ProviderConfiguration:EnableTools:FileSearchRequest");
return bool.Parse(value);
}

Expand Down
1 change: 1 addition & 0 deletions src/Cellm/Cellm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="9.0.3" />
<PackageReference Include="Microsoft.Extensions.Options" Version="9.0.3" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.3" />
<PackageReference Include="ModelContextProtocol" Version="0.1.0-preview.6" />
<PackageReference Include="PdfPig" Version="0.1.10" />
<PackageReference Include="Sentry.Extensions.Logging" Version="5.5.0" />
<PackageReference Include="Sentry.Profiling" Version="5.5.0" />
Expand Down
12 changes: 8 additions & 4 deletions src/Cellm/Services/ServiceLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Cellm.Services.Configuration;
using Cellm.Tools;
using Cellm.Tools.FileReader;
using Cellm.Tools.ModelContextProtocol;
using ExcelDna.Integration;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Memory;
Expand Down Expand Up @@ -48,8 +49,9 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.Build();

services
.Configure<CellmConfiguration>(configuration.GetRequiredSection(nameof(CellmConfiguration)))
.Configure<ProviderConfiguration>(configuration.GetRequiredSection(nameof(ProviderConfiguration)))
.Configure<ModelContextProtocolConfiguration>(configuration.GetRequiredSection(nameof(ModelContextProtocolConfiguration)))
.Configure<CellmConfiguration>(configuration.GetRequiredSection(nameof(CellmConfiguration)))
.Configure<AnthropicConfiguration>(configuration.GetRequiredSection(nameof(AnthropicConfiguration)))
.Configure<DeepSeekConfiguration>(configuration.GetRequiredSection(nameof(DeepSeekConfiguration)))
.Configure<LlamafileConfiguration>(configuration.GetRequiredSection(nameof(LlamafileConfiguration)))
Expand Down Expand Up @@ -129,10 +131,12 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.AddSingleton<FileReaderFactory>()
.AddSingleton<IFileReader, PdfReader>()
.AddSingleton<IFileReader, TextReader>()
.AddSingleton<Functions>()
.AddSingleton<NativeTools>()
.AddTools(
serviceProvider => AIFunctionFactory.Create(serviceProvider.GetRequiredService<Functions>().GlobRequest),
serviceProvider => AIFunctionFactory.Create(serviceProvider.GetRequiredService<Functions>().FileReaderRequest));
serviceProvider => AIFunctionFactory.Create(serviceProvider.GetRequiredService<NativeTools>().FileSearchRequest),
serviceProvider => AIFunctionFactory.Create(serviceProvider.GetRequiredService<NativeTools>().FileReaderRequest));

// Workarounds

// https://github.com/openai/openai-dotnet/issues/297
var metricsFilterDescriptor = services.FirstOrDefault(descriptor =>
Expand Down
5 changes: 5 additions & 0 deletions src/Cellm/Tools/FileSearch/FileSearchRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using MediatR;

namespace Cellm.Tools.FileSearch;

internal record FileSearchRequest(string RootPath, List<string> IncludePatterns, List<string>? ExcludePatterns) : IRequest<FileSearchResponse>;
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
using MediatR;
using Microsoft.Extensions.FileSystemGlobbing;

namespace Cellm.Tools.Glob;
namespace Cellm.Tools.FileSearch;

internal class GlobRequestHandler : IRequestHandler<GlobRequest, GlobResponse>
internal class FileSearchRequestHandler : IRequestHandler<FileSearchRequest, FileSearchResponse>
{
public Task<GlobResponse> Handle(GlobRequest request, CancellationToken cancellationToken)
public Task<FileSearchResponse> Handle(FileSearchRequest request, CancellationToken cancellationToken)
{
var matcher = new Matcher();
matcher.AddIncludePatterns(request.IncludePatterns);
matcher.AddExcludePatterns(request.ExcludePatterns ?? []);
var fileNames = matcher.GetResultsInFullPath(request.RootPath);

return Task.FromResult(new GlobResponse(fileNames.ToList()));
return Task.FromResult(new FileSearchResponse(fileNames.ToList()));
}
}
3 changes: 3 additions & 0 deletions src/Cellm/Tools/FileSearch/FileSearchResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Cellm.Tools.FileSearch;

internal record FileSearchResponse(List<string> FilePaths);
5 changes: 0 additions & 5 deletions src/Cellm/Tools/Glob/GlobRequest.cs

This file was deleted.

3 changes: 0 additions & 3 deletions src/Cellm/Tools/Glob/GlobResponse.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using ModelContextProtocol;

namespace Cellm.Tools.ModelContextProtocol;

internal class ModelContextProtocolConfiguration
{
public List<McpServerConfig> Servers { get; init; } = [];
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using System.ComponentModel;
using Cellm.Tools.FileReader;
using Cellm.Tools.Glob;
using Cellm.Tools.FileSearch;
using MediatR;

namespace Cellm.Tools;
Expand All @@ -9,17 +9,17 @@ namespace Cellm.Tools;
/// Provides an adapter between MediatR and Microsoft.Extensions.AI by wrapping
/// request handlers in function definitions suitable for the AIFunctionFactory.
/// </summary>
internal class Functions(ISender sender)
internal class NativeTools(ISender sender)
{
[Description("Uses glob patterns to search for files on the user's disk and returns matching file paths.")]
[return: Description($"The list of file paths that matches {nameof(includePatterns)} and do not match {nameof(excludePatterns)}")]
public async Task<GlobResponse> GlobRequest(
public async Task<FileSearchResponse> FileSearchRequest(
[Description("The root directory to start the glob search from")] string rootPath,
[Description("The list of glob patterns whose matches will be included in the result")] List<string> includePatterns,
[Description("An optional list of glob patterns whose matches will be excluded from the result")] List<string>? excludePatterns,
CancellationToken cancellationToken)
{
return await sender.Send(new GlobRequest(rootPath, includePatterns, excludePatterns), cancellationToken);
return await sender.Send(new FileSearchRequest(rootPath, includePatterns, excludePatterns), cancellationToken);
}

[Description("Reads a file and returns its content as plain text.")]
Expand Down
19 changes: 18 additions & 1 deletion src/Cellm/appsettings.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,27 @@
"CacheTimeoutInSeconds": 3600,
"EnableCache": true,
"EnableTools": {
"GlobRequest": false,
"FileSearchRequest": false,
"FileReaderRequest": false
},
"EnableModelContextProtocolServers": {
"Everything": true
}
},
"ModelContextProtocolConfiguration": {
"Servers": [
{
"Id": "everything",
"Name": "Everything",
"TransportType": "stdio",
"Location": "path",
"TransportOptions": {
"command": "npx",
"arguments": "-y @modelcontextprotocol/server-everything"
}
}
]
},
"ResilienceConfiguration": {
"RateLimiterConfiguration": {
"QueueLimit": 1048576,
Expand Down
17 changes: 17 additions & 0 deletions src/Cellm/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@
"Microsoft.Extensions.Primitives": "9.0.3"
}
},
"ModelContextProtocol": {
"type": "Direct",
"requested": "[0.1.0-preview.6, )",
"resolved": "0.1.0-preview.6",
"contentHash": "gq6mQYvtaGC8lhWHBS4X5Ck53+HNWZPiqO7hOEOFRRLO30OlrXX9I+Uz9ShvfTAnjwvugV3TfMH2UbpBXusBtw==",
"dependencies": {
"Microsoft.Extensions.AI.Abstractions": "9.3.0-preview.1.25161.3",
"Microsoft.Extensions.Hosting.Abstractions": "9.0.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.0",
"System.Net.ServerSentEvents": "10.0.0-preview.2.25163.2"
}
},
"PdfPig": {
"type": "Direct",
"requested": "[0.1.10, )",
Expand Down Expand Up @@ -566,6 +578,11 @@
"resolved": "9.0.3",
"contentHash": "QH23aqk1Cr1oSP9zEbjsJ60M7nbYOSEQLXszzxK12VXjEOXasnI8pnF7WeME66+z8OoecHfIL8iGxCRxjFQXFQ=="
},
"System.Net.ServerSentEvents": {
"type": "Transitive",
"resolved": "10.0.0-preview.2.25163.2",
"contentHash": "XHyvtQSgco0Sv0kz9yNBv93k3QOoAVzIVd5XbQoTqjV9sqkzWHsToNknyxtNjcXQwb+O9TfzSlNobsBWwnKD3Q=="
},
"System.Text.Json": {
"type": "Transitive",
"resolved": "9.0.3",
Expand Down