Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions src/Cellm/AddIn/ArgumentParser.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.AddIn.Prompts;
using Cellm.Services.Configuration;
using ExcelDna.Integration;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;

namespace Cellm.AddIn;
Expand All @@ -11,17 +12,17 @@ public record Arguments(string Provider, string Model, string Context, string In

public class ArgumentParser
{
private string _provider;
private string _model;
private string? _provider;
private string? _model;
private string? _context;
private string? _instructions;
private double _temperature;
private double? _temperature;

public ArgumentParser(IOptions<CellmConfiguration> cellmConfiguration)
private readonly IConfiguration _configuration;

public ArgumentParser(IConfiguration configuration)
{
_provider = cellmConfiguration.Value.DefaultModelProvider;
_model = cellmConfiguration.Value.DefaultModel;
_temperature = cellmConfiguration.Value.DefaultTemperature;
_configuration = configuration;
}

public ArgumentParser AddProvider(object providerAndModel)
Expand Down Expand Up @@ -138,7 +139,23 @@ public ArgumentParser AddTemperature(object temperature)

public Arguments Parse()
{
if (_context == null)
var provider = _configuration.GetSection(nameof(CellmConfiguration)).GetValue<string>(nameof(CellmConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultProvider));

if (!string.IsNullOrEmpty(_provider))
{
provider = _provider;
}

var model = _configuration.GetSection($"{provider}Configuration").GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(IProviderConfiguration.DefaultModel));

if (!string.IsNullOrEmpty(_model))
{
model = _model;
}

if (_context is null)
{
throw new InvalidOperationException("Context argument is required");
}
Expand All @@ -164,7 +181,14 @@ public Arguments Parse()

instructionsBuilder.AppendLine("</instructions>");

return new Arguments(_provider, _model, contextBuilder.ToString(), instructionsBuilder.ToString(), _temperature);
var temperature = _configuration.GetSection(nameof(CellmConfiguration)).GetValue<double>(nameof(CellmConfiguration.DefaultTemperature));

if (_temperature is not null)
{
temperature = Convert.ToDouble(_temperature);
}

return new Arguments(provider, model, contextBuilder.ToString(), instructionsBuilder.ToString(), temperature);
}

private static string GetProvider(string providerAndModel)
Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/AddIn/CellmConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ public class CellmConfiguration
{
public bool Debug { get; init; }

public string DefaultModelProvider { get; init; } = string.Empty;
public string DefaultProvider { get; init; } = string.Empty;

public string DefaultModel { get; init; } = string.Empty;

Expand Down
8 changes: 4 additions & 4 deletions src/Cellm/AddIn/CellmFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static object Prompt(
var cellmConfiguration = ServiceLocator.Get<IOptions<CellmConfiguration>>().Value;

return PromptWith(
$"{cellmConfiguration.DefaultModelProvider}/{cellmConfiguration.DefaultModel}",
$"{cellmConfiguration.DefaultProvider}/{cellmConfiguration.DefaultModel}",
context,
instructionsOrTemperature,
temperature);
Expand Down Expand Up @@ -87,7 +87,7 @@ public static object PromptWith(
// ExcelAsyncUtil yields Excel's main thread, Task.Run enables async/await in inner code
return ExcelAsyncUtil.Run(nameof(Prompt), new object[] { context, instructionsOrTemperature, temperature }, () =>
{
return Task.Run(async () => await CallModelAsync(prompt)).GetAwaiter().GetResult();
return Task.Run(async () => await CallModelAsync(prompt, arguments.Provider, arguments.Model)).GetAwaiter().GetResult();
});
}
catch (CellmException ex)
Expand All @@ -106,12 +106,12 @@ public static object PromptWith(
/// <returns>A task that represents the asynchronous operation. The task result contains the model's response as a string.</returns>
/// <exception cref="CellmException">Thrown when an unexpected error occurs during the operation.</exception>

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, string? model = null)
private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, string? model = null, Uri? baseAddress = null)
{
try
{
var client = ServiceLocator.Get<IClient>();
var response = await client.Send(prompt, provider, model);
var response = await client.Send(prompt, provider, model, baseAddress);
return response.Messages.Last().Content;
}
catch (CellmException)
Expand Down
11 changes: 7 additions & 4 deletions src/Cellm/Models/Anthropic/AnthropicClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public AnthropicClient(
_serde = serde;
}

public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
public async Task<Prompt> Send(Prompt prompt, string? provider, string? model, Uri? baseAddress)
{
var transaction = SentrySdk.StartTransaction(typeof(AnthropicClient).Name, nameof(Send));
SentrySdk.ConfigureScope(scope => scope.Transaction = transaction);
Expand All @@ -51,7 +51,10 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
var json = _serde.Serialize(requestBody);
var jsonAsString = new StringContent(json, Encoding.UTF8, "application/json");

var response = await _httpClient.PostAsync("/v1/messages", jsonAsString);
const string path = "/v1/messages";
var address = baseAddress is null ? new Uri(path, UriKind.Relative) : new Uri(baseAddress, path);

var response = await _httpClient.PostAsync(address, jsonAsString);
var responseBodyAsString = await response.Content.ReadAsStringAsync();

if (!response.IsSuccessStatusCode)
Expand Down Expand Up @@ -80,7 +83,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
inputTokens,
unit: MeasurementUnit.Custom("token"),
tags: new Dictionary<string, string> {
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultModelProvider },
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultProvider },
{ nameof(model), model?.ToLower() ?? _anthropicConfiguration.DefaultModel },
{ nameof(_httpClient.BaseAddress), _httpClient.BaseAddress?.ToString() ?? string.Empty }
}
Expand All @@ -94,7 +97,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
outputTokens,
unit: MeasurementUnit.Custom("token"),
tags: new Dictionary<string, string> {
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultModelProvider },
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultProvider },
{ nameof(model), model?.ToLower() ?? _anthropicConfiguration.DefaultModel },
{ nameof(_httpClient.BaseAddress), _httpClient.BaseAddress?.ToString() ?? string.Empty }
}
Expand Down
6 changes: 4 additions & 2 deletions src/Cellm/Models/Anthropic/AnthropicConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
namespace Cellm.Models.Anthropic;
using Cellm.Services.Configuration;

internal class AnthropicConfiguration
namespace Cellm.Models.Anthropic;

internal class AnthropicConfiguration : IProviderConfiguration
{
public Uri BaseAddress { get; init; }

Expand Down
6 changes: 3 additions & 3 deletions src/Cellm/Models/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ public Client(IClientFactory clientFactory, IOptions<CellmConfiguration> cellmCo
_cellmConfiguration = cellmConfiguration.Value;
}

public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
public async Task<Prompt> Send(Prompt prompt, string? provider, string? model, Uri? baseAddress)
{
try
{
var client = _clientFactory.GetClient(provider ?? _cellmConfiguration.DefaultModelProvider);
return await client.Send(prompt, provider, model);
var client = _clientFactory.GetClient(provider ?? _cellmConfiguration.DefaultProvider);
return await client.Send(prompt, provider, model, baseAddress);
}
catch (HttpRequestException ex)
{
Expand Down
11 changes: 7 additions & 4 deletions src/Cellm/Models/GoogleAi/GoogleAiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public GoogleAiClient(
_serde = serde;
}

public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
public async Task<Prompt> Send(Prompt prompt, string? provider, string? model, Uri? baseAddress)
{
var transaction = SentrySdk.StartTransaction(typeof(GoogleAiClient).Name, nameof(Send));
SentrySdk.ConfigureScope(scope => scope.Transaction = transaction);
Expand Down Expand Up @@ -57,7 +57,10 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
var json = _serde.Serialize(requestBody);
var jsonAsString = new StringContent(json, Encoding.UTF8, "application/json");

var response = await _httpClient.PostAsync($"/v1beta/models/{model ?? _googleAiConfiguration.DefaultModel}:generateContent?key={_googleAiConfiguration.ApiKey}", jsonAsString);
string path = $"/v1beta/models/{model ?? _googleAiConfiguration.DefaultModel}:generateContent?key={_googleAiConfiguration.ApiKey}";
var address = baseAddress is null ? new Uri(path, UriKind.Relative) : new Uri(baseAddress, path);

var response = await _httpClient.PostAsync(address, jsonAsString);
var responseBodyAsString = await response.Content.ReadAsStringAsync();

if (!response.IsSuccessStatusCode)
Expand Down Expand Up @@ -86,7 +89,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
inputTokens,
unit: MeasurementUnit.Custom("token"),
tags: new Dictionary<string, string> {
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultModelProvider },
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultProvider },
{ nameof(model), model?.ToLower() ?? _googleAiConfiguration.DefaultModel },
{ nameof(_httpClient.BaseAddress), _httpClient.BaseAddress?.ToString() ?? string.Empty }
}
Expand All @@ -100,7 +103,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, string? model)
outputTokens,
unit: MeasurementUnit.Custom("token"),
tags: new Dictionary<string, string> {
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultModelProvider },
{ nameof(provider), provider?.ToLower() ?? _cellmConfiguration.DefaultProvider },
{ nameof(model), model?.ToLower() ?? _googleAiConfiguration.DefaultModel },
{ nameof(_httpClient.BaseAddress), _httpClient.BaseAddress?.ToString() ?? string.Empty }
}
Expand Down
6 changes: 4 additions & 2 deletions src/Cellm/Models/GoogleAi/GoogleAiConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
namespace Cellm.Models.GoogleAi;
using Cellm.Services.Configuration;

internal class GoogleAiConfiguration
namespace Cellm.Models.GoogleAi;

internal class GoogleAiConfiguration : IProviderConfiguration
{
public Uri BaseAddress { get; init; }

Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Models/IClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ namespace Cellm.Models;

internal interface IClient
{
public Task<Prompt> Send(Prompt prompt, string? provider, string? model);
public Task<Prompt> Send(Prompt prompt, string? provider, string? model, Uri? baseAddress);
}
2 changes: 1 addition & 1 deletion src/Cellm/Models/Llamafile/LLamafileProcessManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public LLamafileProcessManager()
Marshal.FreeHGlobal(extendedInfoPtr);
}

public void AssignProcessToCellm(Process process)
public void AssignProcessToExcel(Process process)
{
AssignProcessToJobObject(_jobObject, process.Handle);
}
Expand Down
Loading