Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/Cellm/Models/ModelRequestBehavior/ToolBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Cellm.Models.OpenAi;
using Cellm.Prompts;
using Cellm.Prompts;
using Cellm.Tools;
using MediatR;

Expand All @@ -10,12 +9,12 @@ internal class ToolBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, T
where TResponse : IModelResponse
{
private readonly ISender _sender;
private readonly ITools _tools;
private readonly ToolRunner _toolRunner;

public ToolBehavior(ISender sender, ITools tools)
public ToolBehavior(ISender sender, ToolRunner toolRunner)
{
_sender = sender;
_tools = tools;
_toolRunner = toolRunner;
}

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
Expand All @@ -27,7 +26,8 @@ public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TRe
if (toolCalls is not null)
{
// Model called tools, run tools and call model again
request.Prompt.Messages.Add(await RunTools(toolCalls));
var message = await RunTools(toolCalls);
request.Prompt.Messages.Add(message);
response = await _sender.Send(request, cancellationToken);
}

Expand All @@ -36,7 +36,7 @@ public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TRe

private async Task<Message> RunTools(List<ToolCall> toolCalls)
{
var toolResults = await Task.WhenAll(toolCalls.Select(x => _tools.Run(x)));
var toolResults = await Task.WhenAll(toolCalls.Select(x => _toolRunner.Run(x)));
var toolCallsWithResults = toolCalls
.Zip(toolResults, (toolCall, toolResult) => toolCall with { Result = toolResult })
.ToList();
Expand Down
4 changes: 2 additions & 2 deletions src/Cellm/Models/OpenAi/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ private static List<OpenAiMessage> ToOpenAiToolResults(Message message)
).ToList();
}

public static List<OpenAiTool> ToOpenAiTools(this ITools tools)
public static List<OpenAiTool> ToOpenAiTools(this ToolRunner toolRunner)
{
return tools.GetTools()
return toolRunner.GetTools()
.Select(x => new OpenAiTool("function", new OpenAiFunction(x.Name, x.Description, x.Parameters)))
.ToList();
}
Expand Down
8 changes: 4 additions & 4 deletions src/Cellm/Models/OpenAi/OpenAiRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ internal class OpenAiRequestHandler : IModelRequestHandler<OpenAiRequest, OpenAi
private readonly OpenAiConfiguration _openAiConfiguration;
private readonly CellmConfiguration _cellmConfiguration;
private readonly HttpClient _httpClient;
private readonly ITools _tools;
private readonly ToolRunner _toolRunner;
private readonly ISerde _serde;

public OpenAiRequestHandler(
IOptions<OpenAiConfiguration> openAiConfiguration,
IOptions<CellmConfiguration> cellmConfiguration,
HttpClient httpClient,
ITools tools,
ToolRunner toolRunner,
ISerde serde)
{
_openAiConfiguration = openAiConfiguration.Value;
_cellmConfiguration = cellmConfiguration.Value;
_httpClient = httpClient;
_tools = tools;
_toolRunner = toolRunner;
_serde = serde;
}

Expand Down Expand Up @@ -63,7 +63,7 @@ public string Serialize(OpenAiRequest request)
openAiPrompt.ToOpenAiMessages(),
_cellmConfiguration.MaxOutputTokens,
openAiPrompt.Temperature,
_tools.ToOpenAiTools(),
_toolRunner.ToOpenAiTools(),
"auto");

return _serde.Serialize(chatCompletionRequest, new JsonSerializerOptions
Expand Down
27 changes: 18 additions & 9 deletions src/Cellm/Services/ServiceLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Cellm.Models.PipelineBehavior;
using Cellm.Services.Configuration;
using Cellm.Tools;
using Cellm.Tools.Glob;
using ExcelDna.Integration;
using MediatR;
using Microsoft.Extensions.Caching.Memory;
Expand Down Expand Up @@ -82,22 +83,28 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
sentryLoggingOptions.ExperimentalMetrics = new ExperimentalMetricsOptions { EnableCodeLocations = true };
sentryLoggingOptions.AddIntegration(new ProfilingIntegration());
});
});
})
.AddSingleton(typeof(IPipelineBehavior<,>), typeof(SentryBehavior<,>));

// Internals
services
.AddSingleton(configuration)
.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(Assembly.GetExecutingAssembly()))
.AddTransient(typeof(IPipelineBehavior<,>), typeof(SentryBehavior<,>))
.AddTransient(typeof(IPipelineBehavior<,>), typeof(CachingBehavior<,>))
.AddTransient(typeof(IPipelineBehavior<,>), typeof(ToolBehavior<,>))
.AddMemoryCache()
.AddTransient<ArgumentParser>()
.AddSingleton<IClient, Client>()
.AddSingleton<ITools, Tools.Tools>()
.AddSingleton<ISerde, Serde>();

// Cache
services
.AddMemoryCache()
.AddSingleton<ICache, Cache>()
.AddSingleton<ISerde, Serde>()
.AddSingleton<LLamafileProcessManager>();
.AddSingleton(typeof(IPipelineBehavior<,>), typeof(CachingBehavior<,>));

// Tools
services
.AddSingleton<ToolRunner>()
.AddSingleton<ToolFactory>()
.AddSingleton(typeof(IPipelineBehavior<,>), typeof(ToolBehavior<,>));

// Model Providers
var rateLimiterConfiguration = configuration.GetRequiredSection(nameof(RateLimiterConfiguration)).Get<RateLimiterConfiguration>()
Expand Down Expand Up @@ -139,7 +146,9 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
openAiHttpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {openAiConfiguration.ApiKey}");
}).AddResilienceHandler($"{nameof(OpenAiRequestHandler)}ResiliencePipeline", resiliencePipelineConfigurator.ConfigureResiliencePipeline);

services.AddSingleton<LlamafileRequestHandler>();
services
.AddSingleton<LlamafileRequestHandler>()
.AddSingleton<LLamafileProcessManager>();

return services;
}
Expand Down
48 changes: 0 additions & 48 deletions src/Cellm/Tools/Glob.cs

This file was deleted.

10 changes: 10 additions & 0 deletions src/Cellm/Tools/Glob/GlobRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System.ComponentModel;
using MediatR;

namespace Cellm.Tools.Glob;

[Description("Search for files on the user's disk using glob patterns. Useful when user asks you to find files.")]
internal record GlobRequest(
[Description("The root directory to start the glob search from")] string RootPath,
[Description("List of patterns to include in the search")] List<string> IncludePatterns,
[Description("Optional list of patterns to exclude from the search")] List<string>? ExcludePatterns) : ITool, IRequest<GlobResponse>;
17 changes: 17 additions & 0 deletions src/Cellm/Tools/Glob/GlobRequestHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using MediatR;
using Microsoft.Extensions.FileSystemGlobbing;

namespace Cellm.Tools.Glob;

internal class GlobRequestHandler : IRequestHandler<GlobRequest, GlobResponse>
{
public Task<GlobResponse> Handle(GlobRequest request, CancellationToken cancellationToken)
{
var matcher = new Matcher();
matcher.AddIncludePatterns(request.IncludePatterns);
matcher.AddExcludePatterns(request.ExcludePatterns ?? new List<string>());
var fileNames = matcher.GetResultsInFullPath(request.RootPath);

return Task.FromResult(new GlobResponse(fileNames.ToList()));
}
}
6 changes: 6 additions & 0 deletions src/Cellm/Tools/Glob/GlobResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using System.ComponentModel;

namespace Cellm.Tools.Glob;

internal record GlobResponse(
[Description("List of file paths matching the glob patterns")] List<string> FilePaths);
5 changes: 5 additions & 0 deletions src/Cellm/Tools/ITool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace Cellm.Tools;

public interface ITool
{
}
10 changes: 0 additions & 10 deletions src/Cellm/Tools/ITools.cs

This file was deleted.

92 changes: 92 additions & 0 deletions src/Cellm/Tools/ToolFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
using System.Reflection;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using Json.More;
using Json.Schema;
using Json.Schema.Generation;

namespace Cellm.Tools;

public class ToolFactory
{
public static Tool CreateTool(Type type)
{
var description = GetDescriptionForType(type);

var parameterSchemaBuilder = new JsonSchemaBuilder()
.FromType(type)
.Properties(GetPropertiesForType(type))
.Required(GetRequiredForType(type))
.AdditionalProperties(false);

var parameterSchema = parameterSchemaBuilder.Build();
var parameters = parameterSchema.ToJsonDocument();

return new Tool(
type.Name,
description,
parameters
);
}
private static IReadOnlyDictionary<string, JsonSchema> GetPropertiesForType(Type type)
{
return type
.GetProperties()
.ToDictionary(
property => property.Name,
property => new JsonSchemaBuilder()
.FromType(property.PropertyType)
.Description(GetPropertyDescriptionsForType(type, property.Name))
.Build()
);
}

private static string GetPropertyDescriptionsForType(Type type, string propertyName)
{
return type
.GetConstructors()
.First() // Records have a single constructor
.GetParameters()
.Where(property => property.Name == propertyName)
.Select(property => property.GetCustomAttribute<System.ComponentModel.DescriptionAttribute>()?.Description)
.FirstOrDefault() ?? throw new CellmException($"Cannot get description of {type.Name} property {propertyName}.");
}

public static List<string> GetRequiredForType(Type type)
{
return type
.GetConstructors()
.First() // Records have a single constructor
.GetParameters()
.Where(p => !IsNullableType(p.ParameterType))
.Select(p => p.Name!)
.ToList();
}

private static bool IsNullableType(Type type)
{
// Check if it's a nullable reference type (marked with ?)
if (type.IsValueType == false)
{
var nullabilityInfo = type.GetCustomAttribute<NullableAttribute>();
if (nullabilityInfo != null)
{
return true;
}
}

// Check if it's Nullable<T>
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>))
{
return true;
}

return false;
}

private static string GetDescriptionForType(Type type)
{
var descriptionAttribute = type.GetCustomAttribute<DescriptionAttribute>();
return descriptionAttribute?.Description ?? type.Name;
}
}
44 changes: 44 additions & 0 deletions src/Cellm/Tools/ToolRunner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Cellm.Models;
using Cellm.Prompts;
using Cellm.Tools.Glob;
using MediatR;

namespace Cellm.Tools;

internal class ToolRunner
{
private readonly ISender _sender;
private readonly ISerde _serde;
private readonly ToolFactory _toolFactory;
private readonly IEnumerable<Type> _toolTypes;

public ToolRunner(ISender sender, ISerde serde, ToolFactory toolFactory)
{
_sender = sender;
_serde = serde;
_toolFactory = toolFactory;
_toolTypes = new List<Type>() { typeof(GlobRequest) };
}

public List<Tool> GetTools()
{
return _toolTypes.Select(ToolFactory.CreateTool).ToList();
}

public async Task<string> Run(ToolCall toolCall)
{
return toolCall.Name switch
{
nameof(GlobRequest) => await Run<GlobRequest>(toolCall.Arguments),
_ => throw new ArgumentException($"Unsupported tool: {toolCall.Name}")
};
}

private async Task<string> Run<T>(string arguments)
where T : notnull
{
var request = _serde.Deserialize<T>(arguments);
var response = await _sender.Send(request);
return _serde.Serialize(response);
}
}
Loading