2 changes: 1 addition & 1 deletion src/Cellm/AddIn/CellmFunctions.cs
@@ -119,7 +119,7 @@ public static object PromptWith(
private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
var client = ServiceLocator.Get<Client>();
-var response = await client.Send(prompt, provider, baseAddress);
+var response = await client.Send(prompt, provider, baseAddress, CancellationToken.None);
return response.Messages.Last().Text ?? throw new NullReferenceException("No text response");
}
}
23 changes: 9 additions & 14 deletions src/Cellm/Models/Client.cs
@@ -4,6 +4,7 @@
using Cellm.Models.Llamafile;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
+using Cellm.Models.OpenAiCompatible;
using Cellm.Prompts;
using Cellm.Services.Configuration;
using MediatR;
@@ -12,18 +13,11 @@

namespace Cellm.Models;

-internal class Client
+internal class Client(ISender sender, IOptions<CellmConfiguration> cellmConfiguration)
{
-private readonly CellmConfiguration _cellmConfiguration;
-private readonly ISender _sender;
+private readonly CellmConfiguration _cellmConfiguration = cellmConfiguration.Value;

-public Client(IOptions<CellmConfiguration> cellmConfiguration, ISender sender)
-{
-_cellmConfiguration = cellmConfiguration.Value;
-_sender = sender;
-}
-
-public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress)
+public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress, CancellationToken cancellationToken)
{
try
{
@@ -36,10 +30,11 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress

IModelResponse response = parsedProvider switch
{
-Providers.Anthropic => await _sender.Send(new AnthropicRequest(prompt, provider, baseAddress)),
-Providers.Llamafile => await _sender.Send(new LlamafileRequest(prompt)),
-Providers.Ollama => await _sender.Send(new OllamaRequest(prompt)),
-Providers.OpenAi => await _sender.Send(new OpenAiRequest(prompt)),
+Providers.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
+Providers.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
+Providers.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
+Providers.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
+Providers.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
};

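Note (illustrative, not part of the diff): with the extra CancellationToken parameter and the new OpenAiCompatible branch, a caller can route a prompt to any OpenAI-compatible endpoint through the same Client.Send path used above. A minimal sketch mirroring CallModelAsync; the method name, provider string, and base address are example values only:

```csharp
// Hypothetical helper, not part of this PR; mirrors CellmFunctions.CallModelAsync.
private static async Task<string> CallOpenAiCompatibleAsync(Prompt prompt)
{
    var client = ServiceLocator.Get<Client>();

    // "OpenAiCompatible" selects the new provider branch; the base address points
    // at any OpenAI-compatible server (example: a local vLLM instance).
    var response = await client.Send(
        prompt,
        "OpenAiCompatible",
        new Uri("http://localhost:8000/v1"),
        CancellationToken.None);

    return response.Messages.Last().Text ?? throw new NullReferenceException("No text response");
}
```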
2 changes: 1 addition & 1 deletion src/Cellm/Models/Llamafile/LlamafileRequestHandler.cs
@@ -73,7 +73,7 @@ public async Task<LlamafileResponse> Handle(LlamafileRequest request, Cancellati
// Start server on first call
var llamafile = await _llamafiles[request.Prompt.Options.ModelId ?? _llamafileConfiguration.DefaultModel];

-var openAiResponse = await _sender.Send(new OpenAiCompatibleRequest(request.Prompt, nameof(Llamafile), llamafile.BaseAddress), cancellationToken);
+var openAiResponse = await _sender.Send(new OpenAiCompatibleRequest(request.Prompt, llamafile.BaseAddress), cancellationToken);

return new LlamafileResponse(openAiResponse.Prompt);
}
10 changes: 10 additions & 0 deletions src/Cellm/Models/OpenAiCompatible/OpenAiCompatibleConfiguration.cs
@@ -0,0 +1,10 @@
+namespace Cellm.Models.OpenAiCompatible;
+
+internal class OpenAiCompatibleConfiguration
+{
+public Uri BaseAddress { get; set; } = default!;
+
+public string DefaultModel { get; init; } = string.Empty;
+
+public string ApiKey { get; init; } = string.Empty;
+}
2 changes: 1 addition & 1 deletion src/Cellm/Models/OpenAiCompatible/OpenAiCompatibleRequest.cs
@@ -2,4 +2,4 @@

namespace Cellm.Models.OpenAiCompatible;

-internal record OpenAiCompatibleRequest(Prompt Prompt, string Provider, Uri BaseAddress) : IModelRequest<OpenAiCompatibleResponse>;
+internal record OpenAiCompatibleRequest(Prompt Prompt, Uri? BaseAddress) : IModelRequest<OpenAiCompatibleResponse>;
39 changes: 36 additions & 3 deletions src/Cellm/Models/OpenAiCompatible/OpenAiCompatibleRequestHandler.cs
@@ -1,10 +1,43 @@

+using System.ClientModel;
+using System.ClientModel.Primitives;
using Cellm.Prompts;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.Options;
+using OpenAI;

namespace Cellm.Models.OpenAiCompatible;

-internal class OpenAiCompatibleRequestHandler : IModelRequestHandler<OpenAiCompatibleRequest, OpenAiCompatibleResponse>
+internal class OpenAiCompatibleRequestHandler(HttpClient httpClient, IOptions<OpenAiCompatibleConfiguration> openAiCompatibleConfiguration)
+: IModelRequestHandler<OpenAiCompatibleRequest, OpenAiCompatibleResponse>
{
-public Task<OpenAiCompatibleResponse> Handle(OpenAiCompatibleRequest request, CancellationToken cancellationToken)
+private readonly OpenAiCompatibleConfiguration _openAiCompatibleConfiguration = openAiCompatibleConfiguration.Value;
+
+public async Task<OpenAiCompatibleResponse> Handle(OpenAiCompatibleRequest request, CancellationToken cancellationToken)
{
-throw new NotImplementedException();
+var chatClient = CreateChatClient(request.BaseAddress);
+
+var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken);
+
+var prompt = new PromptBuilder(request.Prompt)
+.AddMessage(chatCompletion.Message)
+.Build();
+
+return new OpenAiCompatibleResponse(prompt);
+}
+
+private IChatClient CreateChatClient(Uri? baseAddress)
+{
+var openAiClient = new OpenAIClient(
+new ApiKeyCredential(_openAiCompatibleConfiguration.ApiKey),
+new OpenAIClientOptions
+{
+Transport = new HttpClientPipelineTransport(httpClient),
+Endpoint = baseAddress ?? _openAiCompatibleConfiguration.BaseAddress
+});
+
+return new ChatClientBuilder(openAiClient.AsChatClient(_openAiCompatibleConfiguration.DefaultModel))
+.UseFunctionInvocation()
+.Build();
+}
+}
32 changes: 6 additions & 26 deletions src/Cellm/Models/OpenAiCompatible/SerrviceCollectionExtensions.cs
@@ -1,43 +1,23 @@
-using Cellm.Services.Configuration;
+using Cellm.Models.Anthropic;
+using Cellm.Services.Configuration;
+using MediatR;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace Cellm.Models.OpenAiCompatible;

internal static class ServiceCollectionExtensions
{
-public static IServiceCollection AddOpenAiCompatibleChatClient(this IServiceCollection services, string provider, IConfiguration configuration)
+public static IServiceCollection AddOpenAiCompatibleChatClient(this IServiceCollection services, IConfiguration configuration)
{
var resiliencePipelineConfigurator = new ResiliencePipelineConfigurator(configuration);

services
-.AddHttpClient(provider, openAiCompatibleHttpClient =>
+.AddHttpClient<IRequestHandler<OpenAiCompatibleRequest, OpenAiCompatibleResponse>, OpenAiCompatibleRequestHandler>(openAiCompatibleHttpClient =>
{
openAiCompatibleHttpClient.Timeout = TimeSpan.FromHours(1);
})
-.AddResilienceHandler($"{nameof(OpenAiCompatibleRequestHandler)}ResiliencePipeline", resiliencePipelineConfigurator.ConfigureResiliencePipeline);
-
-// This is probably not needed, because we would send a OpenAiCompatibleRequestHandler(Prompt prompt, Uri BaseAddress) and instantiate a client on each call
-//var openAiCompatibleConfiguration = configuration.GetRequiredSection($"{provider}Configuration").Get<OpenAiCompatibleConfiguration>()
-// ?? throw new NullReferenceException(nameof(provider));
-
-//services
-// .AddKeyedChatClient(Providers.OpenAiCompatible, serviceProvider =>
-// {
-// var openAiCompatibleHttpClient = serviceProvider
-// .GetRequiredService<IHttpClientFactory>()
-// .CreateClient(provider);
-
-// var openAiClient = new OpenAIClient(
-// new ApiKeyCredential(openAiCompatibleConfiguration.ApiKey),
-// new OpenAIClientOptions
-// {
-// Transport = new HttpClientPipelineTransport(openAiCompatibleHttpClient),
-// });
-
-// return openAiClient.AsChatClient(openAiCompatibleConfiguration.DefaultModel);
-// })
-// .UseFunctionInvocation();
+.AddResilienceHandler(nameof(OpenAiCompatibleRequestHandler), resiliencePipelineConfigurator.ConfigureResiliencePipeline);

return services;
}
7 changes: 5 additions & 2 deletions src/Cellm/Services/ServiceLocator.cs
@@ -8,6 +8,7 @@
using Cellm.Models.ModelRequestBehavior;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
+using Cellm.Models.OpenAiCompatible;
using Cellm.Services.Configuration;
using Cellm.Tools;
using Cellm.Tools.FileReader;
@@ -50,6 +51,7 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.Configure<AnthropicConfiguration>(configuration.GetRequiredSection(nameof(AnthropicConfiguration)))
.Configure<OllamaConfiguration>(configuration.GetRequiredSection(nameof(OllamaConfiguration)))
.Configure<OpenAiConfiguration>(configuration.GetRequiredSection(nameof(OpenAiConfiguration)))
+.Configure<OpenAiCompatibleConfiguration>(configuration.GetRequiredSection(nameof(OpenAiCompatibleConfiguration)))
.Configure<LlamafileConfiguration>(configuration.GetRequiredSection(nameof(LlamafileConfiguration)))
.Configure<RateLimiterConfiguration>(configuration.GetRequiredSection(nameof(RateLimiterConfiguration)))
.Configure<CircuitBreakerConfiguration>(configuration.GetRequiredSection(nameof(CircuitBreakerConfiguration)))
@@ -124,8 +126,9 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.AddResilienceHandler($"{nameof(AnthropicRequestHandler)}ResiliencePipeline", resiliencePipelineConfigurator.ConfigureResiliencePipeline);

services
-.AddOpenOllamaChatClient(configuration)
-.AddOpenAiChatClient(configuration);
+.AddOpenAiChatClient(configuration)
+.AddOpenAiCompatibleChatClient(configuration)
+.AddOpenOllamaChatClient(configuration);

// Model request pipeline
services
2 changes: 1 addition & 1 deletion src/Cellm/appsettings.Local.Google.json
@@ -1,6 +1,6 @@
{
"OpenAiCompatibleConfiguration": {
"BaseAddress": "https://generativelanguage.googleapis.com/v1beta/openai/v1",
"BaseAddress": "https://generativelanguage.googleapis.com/v1beta/openai",
"DefaultModel": "gemini-1.5-flash",
"ApiKey": "YOUR_GEMINI_API_KEY"

2 changes: 1 addition & 1 deletion src/Cellm/appsettings.Local.Mistral.json
@@ -1,6 +1,6 @@
{
"OpenAiCompatibleConfiguration": {
"BaseAddress": "https://api.mistral.ai",
"BaseAddress": "https://api.mistral.ai/v1",
"DefaultModel": "mistral-small-latest",
"ApiKey": "YOUR_MISTRAL_API_KEY"
},
2 changes: 1 addition & 1 deletion src/Cellm/appsettings.Local.vLLM.json
@@ -1,6 +1,6 @@
{
"OpenAiCompatibleConfiguration": {
"BaseAddress": "http://localhost:8000"
"BaseAddress": "http://localhost:8000/v1"
},
"CellmConfiguration": {
"DefaultProvider": "OpenAiCompatible"
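Note (illustrative, not part of the diff): each of the local-override files above follows the same OpenAiCompatibleConfiguration shape, with the BaseAddress now including any path the server expects (e.g. /v1). A sketch for some other OpenAI-compatible server, using placeholder values:

```json
{
  "OpenAiCompatibleConfiguration": {
    "BaseAddress": "http://localhost:1234/v1",
    "DefaultModel": "your-model-name",
    "ApiKey": "YOUR_API_KEY"
  },
  "CellmConfiguration": {
    "DefaultProvider": "OpenAiCompatible"
  }
}
```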