Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions src/Cellm/AddIn/Functions.cs → src/Cellm/AddIn/CellmFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Cellm.AddIn;

public static class Functions
public static class CellmFunctions
{
/// <summary>
/// Sends a prompt to the default model configured in CellmConfiguration.
Expand Down Expand Up @@ -73,7 +73,7 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.Get<PromptWithArgumentParser>()
var arguments = ServiceLocator.Get<PromptArgumentParser>()
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
Expand All @@ -88,8 +88,8 @@ public static object PromptWith(

var prompt = new PromptBuilder()
.SetModel(arguments.Model)
.SetSystemMessage(SystemMessages.SystemMessage)
.SetTemperature(arguments.Temperature)
.AddSystemMessage(SystemMessages.SystemMessage)
.AddUserMessage(userMessage)
.Build();

Expand All @@ -102,6 +102,7 @@ public static object PromptWith(
catch (CellmException ex)
{
SentrySdk.CaptureException(ex);
Debug.WriteLine(ex);
return ex.Message;
}
}
Expand All @@ -117,22 +118,8 @@ public static object PromptWith(

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
try
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
var content = response.Messages.Last().Content;
return content;
}
catch (CellmException ex)
{
Debug.WriteLine(ex);
throw;
}
catch (Exception ex)
{
Debug.WriteLine(ex);
throw new CellmException("An unexpected error occurred", ex);
}
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
return response.Messages.Last().Content;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;
Expand All @@ -8,7 +9,7 @@ namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);

public class PromptWithArgumentParser
public class PromptArgumentParser
{
private string? _provider;
private string? _model;
Expand All @@ -18,12 +19,12 @@ public class PromptWithArgumentParser

private readonly IConfiguration _configuration;

public PromptWithArgumentParser(IConfiguration configuration)
public PromptArgumentParser(IConfiguration configuration)
{
_configuration = configuration;
}

public PromptWithArgumentParser AddProvider(object providerAndModel)
public PromptArgumentParser AddProvider(object providerAndModel)
{
_provider = providerAndModel switch
{
Expand All @@ -35,7 +36,7 @@ public PromptWithArgumentParser AddProvider(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddModel(object providerAndModel)
public PromptArgumentParser AddModel(object providerAndModel)
{
_model = providerAndModel switch
{
Expand All @@ -47,21 +48,21 @@ public PromptWithArgumentParser AddModel(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddInstructionsOrContext(object instructionsOrContext)
public PromptArgumentParser AddInstructionsOrContext(object instructionsOrContext)
{
_instructionsOrContext = instructionsOrContext;

return this;
}

public PromptWithArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
public PromptArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
{
_instructionsOrTemperature = instructionsOrTemperature;

return this;
}

public PromptWithArgumentParser AddTemperature(object temperature)
public PromptArgumentParser AddTemperature(object temperature)
{
_temperature = temperature;

Expand Down Expand Up @@ -92,25 +93,25 @@ public Arguments Parse()
// "=PROMPT("Extract keywords", 0.7)
(string instructions, double temperature, ExcelMissing) => new Arguments(provider, model, string.Empty, RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2)
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, 0.7)
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords")
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords", 0.7)
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2, C1:D2)
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, C1:D2, 0.7)
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
// Anything else
_ => throw new ArgumentException($"Invalid arguments ({_instructionsOrContext?.GetType().Name}, {_instructionsOrTemperature?.GetType().Name}, {_temperature?.GetType().Name})")
};
}

private static string GetProvider(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
Expand All @@ -122,7 +123,7 @@ private static string GetProvider(string providerAndModel)

private static string GetModel(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
Expand Down Expand Up @@ -203,7 +204,7 @@ private static string GetRowName(int rowNumber)
return (rowNumber + 1).ToString();
}

private static string RenderContext(string context)
private static string RenderCells(string context)
{
return new StringBuilder()
.AppendLine("<context>")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace Cellm.AddIn;
namespace Cellm.Prompts;

internal static class SystemMessages
{
Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Services/ServiceLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.AddSingleton(configuration)
.AddMemoryCache()
.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(Assembly.GetExecutingAssembly()))
.AddTransient<PromptWithArgumentParser>()
.AddTransient<PromptArgumentParser>()
.AddSingleton<Client>()
.AddSingleton<Serde>();

Expand Down