Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/Cellm.Models/Prompts/PromptBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ namespace Cellm.Models.Prompts;

public class PromptBuilder
{
private List<ChatMessage> _messages = new();
private ChatOptions _options = new();
private readonly List<ChatMessage> _messages = [];
private readonly ChatOptions _options = new();

public PromptBuilder()
{
Expand All @@ -30,6 +30,12 @@ public PromptBuilder SetTemperature(double temperature)
return this;
}

public PromptBuilder SetMaxOutputTokens(int maxOutputTokens)
{
_options.MaxOutputTokens = maxOutputTokens;
return this;
}

public PromptBuilder AddSystemMessage(string content)
{
_messages.Add(new ChatMessage(ChatRole.System, content));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
using Cellm.Models.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Anthropic;

internal class AnthropicRequestHandler(
[FromKeyedServices(Provider.Anthropic)] IChatClient chatClient,
IOptionsMonitor<ProviderConfiguration> providerConfiguration)
[FromKeyedServices(Provider.Anthropic)] IChatClient chatClient)
: IModelRequestHandler<AnthropicRequest, AnthropicResponse>
{
public async Task<AnthropicResponse> Handle(AnthropicRequest request, CancellationToken cancellationToken)
{
// Required by Anthropic API
request.Prompt.Options.MaxOutputTokens ??= providerConfiguration.CurrentValue.MaxOutputTokens;

var chatResponse = await chatClient.GetResponseAsync(
request.Prompt.Messages,
request.Prompt.Options,
Expand Down
54 changes: 25 additions & 29 deletions src/Cellm.Models/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Mistral.SDK;
using OpenAI;
using Polly;
using Polly.CircuitBreaker;
Expand Down Expand Up @@ -164,27 +165,6 @@ public static IServiceCollection AddCellmChatClient(this IServiceCollection serv
return services;
}

public static IServiceCollection AddOllamaChatClient(this IServiceCollection services)
{
services
.AddKeyedChatClient(Provider.Ollama, serviceProvider =>
{
var account = ServiceLocator.ServiceProvider.GetRequiredService<Account>();
account.RequireEntitlement(Entitlement.EnableOllamaProvider);

var ollamaConfiguration = serviceProvider.GetRequiredService<IOptionsMonitor<OllamaConfiguration>>();
var resilientHttpClient = serviceProvider.GetKeyedService<HttpClient>("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient");

return new OllamaChatClient(
ollamaConfiguration.CurrentValue.BaseAddress,
ollamaConfiguration.CurrentValue.DefaultModel,
resilientHttpClient);
}, ServiceLifetime.Transient)
.UseFunctionInvocation();

return services;
}

public static IServiceCollection AddDeepSeekChatClient(this IServiceCollection services)
{
services
Expand Down Expand Up @@ -222,15 +202,31 @@ public static IServiceCollection AddMistralChatClient(this IServiceCollection se
var mistralConfiguration = serviceProvider.GetRequiredService<IOptionsMonitor<MistralConfiguration>>();
var resilientHttpClient = serviceProvider.GetKeyedService<HttpClient>("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient");

var openAiClient = new OpenAIClient(
new ApiKeyCredential(mistralConfiguration.CurrentValue.ApiKey),
new OpenAIClientOptions
{
Transport = new HttpClientPipelineTransport(resilientHttpClient),
Endpoint = mistralConfiguration.CurrentValue.BaseAddress
});
return new MistralClient(mistralConfiguration.CurrentValue.ApiKey, resilientHttpClient)
.Completions
.AsBuilder()
.Build();
}, ServiceLifetime.Transient)
.UseFunctionInvocation();

return services;
}

public static IServiceCollection AddOllamaChatClient(this IServiceCollection services)
{
services
.AddKeyedChatClient(Provider.Ollama, serviceProvider =>
{
var account = ServiceLocator.ServiceProvider.GetRequiredService<Account>();
account.RequireEntitlement(Entitlement.EnableOllamaProvider);

var ollamaConfiguration = serviceProvider.GetRequiredService<IOptionsMonitor<OllamaConfiguration>>();
var resilientHttpClient = serviceProvider.GetKeyedService<HttpClient>("ResilientHttpClient") ?? throw new NullReferenceException("ResilientHttpClient");

return openAiClient.GetChatClient(mistralConfiguration.CurrentValue.DefaultModel).AsIChatClient();
return new OllamaChatClient(
ollamaConfiguration.CurrentValue.BaseAddress,
ollamaConfiguration.CurrentValue.DefaultModel,
resilientHttpClient);
}, ServiceLifetime.Transient)
.UseFunctionInvocation();

Expand Down
4 changes: 1 addition & 3 deletions src/Cellm/AddIn/ArgumentParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

namespace Cellm.AddIn;

public record Arguments(Provider Provider, string Model, string Context, string Instructions, double Temperature);

public class ArgumentParser
{
private string? _provider;
Expand Down Expand Up @@ -231,7 +229,7 @@ private static string RenderInstructions(string instructions)
.ToString();
}

private double ParseTemperature(double temperature)
private static double ParseTemperature(double temperature)
{
if (temperature < 0 || temperature > 1)
{
Expand Down
5 changes: 5 additions & 0 deletions src/Cellm/AddIn/Arguments.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using Cellm.Models.Providers;

namespace Cellm.AddIn;

public record Arguments(Provider Provider, string Model, string Context, string Instructions, double Temperature);
6 changes: 3 additions & 3 deletions src/Cellm/AddIn/ExcelAddin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ public void AutoOpen()
{
ExcelIntegration.RegisterUnhandledExceptionHandler(obj =>
{
var ex = (Exception)obj;
SentrySdk.CaptureException(ex);
return ex.Message;
var e = (Exception)obj;
SentrySdk.CaptureException(e);
return e.Message;
});

_ = ServiceLocator.ServiceProvider;
Expand Down
15 changes: 10 additions & 5 deletions src/Cellm/AddIn/ExcelFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

namespace Cellm.AddIn;

Expand Down Expand Up @@ -74,7 +75,10 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.ServiceProvider.GetRequiredService<ArgumentParser>()
var argumentParser = ServiceLocator.ServiceProvider.GetRequiredService<ArgumentParser>();
var providerConfiguration = ServiceLocator.ServiceProvider.GetRequiredService<IOptionsMonitor<ProviderConfiguration>>();

var arguments = argumentParser
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
Expand All @@ -90,6 +94,7 @@ public static object PromptWith(
var prompt = new PromptBuilder()
.SetModel(arguments.Model)
.SetTemperature(arguments.Temperature)
.SetMaxOutputTokens(providerConfiguration.CurrentValue.MaxOutputTokens)
.AddSystemMessage(SystemMessages.SystemMessage)
.AddUserMessage(userMessage)
.Build();
Expand All @@ -101,11 +106,11 @@ public static object PromptWith(
});

}
catch (CellmException ex)
catch (CellmException e)
{
SentrySdk.CaptureException(ex);
Debug.WriteLine(ex);
return ex.Message;
SentrySdk.CaptureException(e);
Debug.WriteLine(e);
return e.Message;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/Cellm/Cellm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="9.0.4" />
<PackageReference Include="Microsoft.Extensions.Options" Version="9.0.4" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.4" />
<PackageReference Include="Mistral.SDK" Version="2.1.1" />
<PackageReference Include="ModelContextProtocol" Version="0.1.0-preview.7" />
<PackageReference Include="PdfPig" Version="0.1.10" />
<PackageReference Include="Sentry.Extensions.Logging" Version="5.5.1" />
Expand Down
10 changes: 10 additions & 0 deletions src/Cellm/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,16 @@
"Microsoft.Extensions.Primitives": "9.0.4"
}
},
"Mistral.SDK": {
"type": "Direct",
"requested": "[2.1.1, )",
"resolved": "2.1.1",
"contentHash": "dBTLqmtfj7C62meCEB9l7VKDtRDDFQgYbx8a5+8uTLtU9bUmw9xYMmyTixHcfOGAQ8LlrXOWNOS6dEFMEgFHhQ==",
"dependencies": {
"Microsoft.Bcl.AsyncInterfaces": "8.0.0",
"Microsoft.Extensions.AI.Abstractions": "9.3.0-preview.1.25161.3"
}
},
"ModelContextProtocol": {
"type": "Direct",
"requested": "[0.1.0-preview.7, )",
Expand Down