2 changes: 1 addition & 1 deletion README.md
@@ -204,7 +204,7 @@ To get started, we recommend using Ollama with the Gemma 2 2B model:

1. Rename `appsettings.Ollama.json` to `appsettings.Local.json`.
2. Build and install Cellm.
3. Run the following command in the docker directory:
3. Run the following command in the `docker/` directory:
```cmd
docker compose -f docker-compose.Ollama.yml up --detach
docker compose -f docker-compose.Ollama.yml exec backend ollama pull gemma2:2b
```
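Once both commands have run, you can sanity-check that the backend is up and the model was pulled. This assumes the compose file maps Ollama's default port 11434 to the host (an assumption; adjust to your compose file):

```cmd
curl http://localhost:11434/api/tags
```

Ollama's `/api/tags` endpoint lists the models available locally, so `gemma2:2b` should appear in the output once the pull finishes.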
2 changes: 2 additions & 0 deletions src/Cellm/Models/Client.cs
@@ -3,6 +3,7 @@
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic;
using Cellm.Models.Llamafile;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
using Cellm.Prompts;
using MediatR;
@@ -37,6 +38,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress
{
Providers.Anthropic => await _sender.Send(new AnthropicRequest(prompt, provider, baseAddress)),
Providers.Llamafile => await _sender.Send(new LlamafileRequest(prompt)),
Providers.Ollama => await _sender.Send(new OllamaRequest(prompt, provider, baseAddress)),
Providers.OpenAi => await _sender.Send(new OpenAiRequest(prompt, provider, baseAddress)),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
};
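Each arm of the switch above dispatches to a provider-specific MediatR handler. The PR's `OllamaRequest` and its handler are not shown in this diff; a hypothetical sketch of the pattern, with names modeled on the existing `OpenAiRequest` pairing (the type members and fallback address are assumptions, not the PR's actual code):

```csharp
using System.Text;
using Cellm.Prompts;
using MediatR;

// Hypothetical sketch: each provider pairs a request record with a MediatR
// handler, so Client.Send only needs the switch above to dispatch.
internal record OllamaRequest(Prompt Prompt, string? Provider, Uri? BaseAddress)
    : IRequest<OllamaResponse>;

internal record OllamaResponse(Prompt Prompt);

internal class OllamaRequestHandler(HttpClient httpClient)
    : IRequestHandler<OllamaRequest, OllamaResponse>
{
    public async Task<OllamaResponse> Handle(OllamaRequest request, CancellationToken cancellationToken)
    {
        // Ollama's default local address; the real handler presumably falls
        // back to configuration when no address is passed in.
        var baseAddress = request.BaseAddress ?? new Uri("http://localhost:11434");

        // Placeholder body: the real handler would serialize request.Prompt
        // into Ollama's /api/chat schema and parse the reply into a new Prompt.
        var json = "{\"model\":\"gemma2:2b\",\"messages\":[],\"stream\":false}";
        using var content = new StringContent(json, Encoding.UTF8, "application/json");

        var response = await httpClient.PostAsync(new Uri(baseAddress, "api/chat"), content, cancellationToken);
        response.EnsureSuccessStatusCode();

        return new OllamaResponse(request.Prompt);
    }
}
```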
143 changes: 16 additions & 127 deletions src/Cellm/Models/Llamafile/LlamafileRequestHandler.cs
@@ -1,7 +1,7 @@
using System.Diagnostics;
using System.Net.NetworkInformation;
using Cellm.AddIn;
using Cellm.AddIn.Exceptions;
using Cellm.Models.Local;
using Cellm.Models.OpenAi;
using MediatR;
using Microsoft.Extensions.Options;
@@ -14,41 +14,42 @@ private record Llamafile(string ModelPath, Uri BaseAddress, Process Process);

private readonly AsyncLazy<string> _llamafileExePath;
private readonly Dictionary<string, AsyncLazy<Llamafile>> _llamafiles;
private readonly LLamafileProcessManager _llamafileProcessManager;
private readonly ProcessManager _processManager;

private readonly CellmConfiguration _cellmConfiguration;
private readonly LlamafileConfiguration _llamafileConfiguration;
private readonly OpenAiConfiguration _openAiConfiguration;

private readonly ISender _sender;
private readonly HttpClient _httpClient;
private readonly LocalUtilities _localUtilities;

public LlamafileRequestHandler(IOptions<CellmConfiguration> cellmConfiguration,
IOptions<LlamafileConfiguration> llamafileConfiguration,
IOptions<OpenAiConfiguration> openAiConfiguration,
ISender sender,
HttpClient httpClient,
LLamafileProcessManager llamafileProcessManager)
LocalUtilities localUtilities,
ProcessManager processManager)
{
_cellmConfiguration = cellmConfiguration.Value;
_llamafileConfiguration = llamafileConfiguration.Value;
_openAiConfiguration = openAiConfiguration.Value;
_sender = sender;
_httpClient = httpClient;
_llamafileProcessManager = llamafileProcessManager;
_localUtilities = localUtilities;
_processManager = processManager;

_llamafileExePath = new AsyncLazy<string>(async () =>
{
return await DownloadFile(_llamafileConfiguration.LlamafileUrl, $"{nameof(Llamafile)}.exe");
var llamafileName = Path.GetFileName(_llamafileConfiguration.LlamafileUrl.Segments.Last());
return await _localUtilities.DownloadFile(_llamafileConfiguration.LlamafileUrl, $"{llamafileName}.exe");
});

_llamafiles = _llamafileConfiguration.Models.ToDictionary(x => x.Key, x => new AsyncLazy<Llamafile>(async () =>
{
// Download model
var modelPath = await DownloadFile(x.Value, CreateFilePath(CreateModelFileName(x.Key)));
var modelPath = await _localUtilities.DownloadFile(x.Value, _localUtilities.CreateCellmFilePath(CreateModelFileName(x.Key)));

// Run Llamafile
var baseAddress = CreateBaseAddress();
// Start server
var baseAddress = new UriBuilder("http", "localhost", _localUtilities.FindPort()).Uri;
var process = await StartProcess(modelPath, baseAddress);

return new Llamafile(modelPath, baseAddress, process);
@@ -101,130 +102,18 @@ private async Task<Process> StartProcess(string modelPath, Uri baseAddress)
process.BeginErrorReadLine();
}

await WaitForLlamafile(baseAddress, process);
var address = new Uri(baseAddress, "health");
await _localUtilities.WaitForServer(address, process);

// Kill the process when Excel exits or dies
_llamafileProcessManager.AssignProcessToExcel(process);
// Kill Llamafile when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}

private async Task<string> DownloadFile(Uri uri, string filePath)
{
if (File.Exists(filePath))
{
return filePath;
}

var filePathPart = $"{filePath}.part";

if (File.Exists(filePathPart))
{
File.Delete(filePathPart);
}

var response = await _httpClient.GetAsync(uri, HttpCompletionOption.ResponseHeadersRead);
response.EnsureSuccessStatusCode();

using (var fileStream = File.Create(filePathPart))
using (var httpStream = await response.Content.ReadAsStreamAsync())
{

await httpStream.CopyToAsync(fileStream);
}

File.Move(filePathPart, filePath);

return filePath;
}

private async Task WaitForLlamafile(Uri baseAddress, Process process)
{
var startTime = DateTime.UtcNow;

// Wait max 30 seconds to load model
while ((DateTime.UtcNow - startTime).TotalSeconds < 30)
{
if (process.HasExited)
{
throw new CellmException($"Failed to run Llamafile, process exited. Exit code: {process.ExitCode}");
}

try
{
var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(1));
var response = await _httpClient.GetAsync(new Uri(baseAddress, "health"), cancellationTokenSource.Token);
if (response.StatusCode == System.Net.HttpStatusCode.OK)
{
// Server is ready
return;
}
}
catch (HttpRequestException)
{
}
catch (TaskCanceledException)
{
}

// Wait before next attempt
await Task.Delay(500);
}

process.Kill();

throw new CellmException("Failed to run Llamafile, timeout waiting for Llamafile server to start");
}

string CreateFilePath(string fileName)
{
var filePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), nameof(Cellm), fileName);
Directory.CreateDirectory(Path.GetDirectoryName(filePath) ?? throw new CellmException("Failed to create Llamafile folder"));
return filePath;
}

private static string CreateModelFileName(string modelName)
{
return $"Llamafile-model-{modelName}";
}

private Uri CreateBaseAddress()
{
var uriBuilder = new UriBuilder(_llamafileConfiguration.BaseAddress)
{
Port = GetFirstUnusedPort()
};

return uriBuilder.Uri;
}

private static int GetFirstUnusedPort(ushort min = 49152, ushort max = 65535)
{
if (max < min)
{
throw new ArgumentException("Max port must be larger than min port.");
}

var ipProperties = IPGlobalProperties.GetIPGlobalProperties();

var activePorts = ipProperties.GetActiveTcpConnections()
.Where(connection => connection.State != TcpState.Closed)
.Select(connection => connection.LocalEndPoint)
.Concat(ipProperties.GetActiveTcpListeners())
.Concat(ipProperties.GetActiveUdpListeners())
.Select(endpoint => endpoint.Port)
.ToArray();

var firstInactivePort = Enumerable.Range(min, max)
.Where(port => !activePorts.Contains(port))
.FirstOrDefault();

if (firstInactivePort == default)
{
throw new CellmException($"All local TCP ports between {min} and {max} are currently in use.");
}

return firstInactivePort;
}
}
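The handler defers both the model download and the server start behind `AsyncLazy`, whose definition is not part of this diff. A minimal sketch of the usual pattern it follows (an assumption, not the project's actual helper): wrap the factory in a `Lazy<Task<T>>` so the expensive work runs exactly once, on first await, even under concurrent callers.

```csharp
using System.Runtime.CompilerServices;

// Sketch of an AsyncLazy<T> in the style used above; Cellm's real helper may differ.
public class AsyncLazy<T>(Func<Task<T>> factory)
{
    // Lazy<Task<T>> guarantees the factory runs at most once.
    private readonly Lazy<Task<T>> _instance = new(() => Task.Run(factory));

    // Awaiting the AsyncLazy awaits the single underlying task.
    public TaskAwaiter<T> GetAwaiter() => _instance.Value.GetAwaiter();
}
```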

150 changes: 150 additions & 0 deletions src/Cellm/Models/Local/LocalUtilities.cs
@@ -0,0 +1,150 @@
using System.Diagnostics;
using System.IO.Compression;
using System.Net.NetworkInformation;
using Cellm.AddIn.Exceptions;

namespace Cellm.Models.Local;

internal class LocalUtilities(HttpClient httpClient)
{
public async Task<string> DownloadFile(Uri uri, string filePath)
{
if (File.Exists(filePath))
{
return filePath;
}

var filePathPart = $"{filePath}.part";

if (File.Exists(filePathPart))
{
File.Delete(filePathPart);
}

var response = await httpClient.GetAsync(uri, HttpCompletionOption.ResponseHeadersRead);
response.EnsureSuccessStatusCode();

using (var fileStream = File.Create(filePathPart))
using (var httpStream = await response.Content.ReadAsStreamAsync())
{
await httpStream.CopyToAsync(fileStream);
}

File.Move(filePathPart, filePath);

return filePath;
}

public async Task WaitForServer(Uri endpoint, Process process)
{
var startTime = DateTime.UtcNow;

// Wait up to 30 seconds for the server to load the model
while ((DateTime.UtcNow - startTime).TotalSeconds < 30)
{
if (process.HasExited)
{
throw new CellmException($"Failed to run server, process exited. Exit code: {process.ExitCode}");
}

try
{
using var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(1));
var response = await httpClient.GetAsync(endpoint, cancellationTokenSource.Token);
if (response.StatusCode == System.Net.HttpStatusCode.OK)
{
// Server is ready
return;
}
}
catch (HttpRequestException)
{
// Server is not listening yet, try again
}
catch (TaskCanceledException)
{
// Health check timed out, try again
}

// Wait before next attempt
await Task.Delay(500);
}

process.Kill();

throw new CellmException($"Timed out waiting for server at {endpoint} to start");
}

public string CreateCellmDirectory(params string[] subFolders)
{
var folderPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), nameof(Cellm));

if (subFolders.Length > 0)
{
folderPath = Path.Combine(subFolders.Prepend(folderPath).ToArray());
}

Directory.CreateDirectory(folderPath);
return folderPath;
}

public string CreateCellmFilePath(string fileName)
{
return Path.Combine(CreateCellmDirectory(), fileName);
}

public int FindPort(ushort min = 49152, ushort max = 65535)
{
if (max < min)
{
throw new ArgumentException("Max port must be larger than min port.");
}

var ipProperties = IPGlobalProperties.GetIPGlobalProperties();

var activePorts = ipProperties.GetActiveTcpConnections()
.Where(connection => connection.State != TcpState.Closed)
.Select(connection => connection.LocalEndPoint)
.Concat(ipProperties.GetActiveTcpListeners())
.Concat(ipProperties.GetActiveUdpListeners())
.Select(endpoint => endpoint.Port)
.ToArray();

// Enumerable.Range takes (start, count), not (start, end)
var firstInactivePort = Enumerable.Range(min, max - min + 1)
.Where(port => !activePorts.Contains(port))
.FirstOrDefault();

if (firstInactivePort == default)
{
throw new CellmException($"All local TCP ports between {min} and {max} are currently in use.");
}

return firstInactivePort;
}

public string ExtractFile(string zipFilePath, string targetDirectory)
{
using (ZipArchive archive = ZipFile.OpenRead(zipFilePath))
{
// Skip directory entries, which have an empty Name
foreach (ZipArchiveEntry entry in archive.Entries.Where(entry => entry.Name.Length > 0))
{
string destinationPath = Path.Combine(targetDirectory, entry.FullName);
var fileInfo = new FileInfo(destinationPath);

// Re-extract the archive if any entry is missing or has the wrong size;
// overwrite because earlier entries may already exist on disk
if (!fileInfo.Exists || fileInfo.Length != entry.Length)
{
ZipFile.ExtractToDirectory(zipFilePath, targetDirectory, overwriteFiles: true);
return targetDirectory;
}
}
}

return targetDirectory;
}
}
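A hedged sketch of how a request handler can chain these helpers end to end; the URL and file names below are placeholders, not values from the PR:

```csharp
// Placeholder wiring, assuming an injected HttpClient as above.
var utilities = new LocalUtilities(new HttpClient());

// DownloadFile returns early if the file is already on disk.
var modelPath = await utilities.DownloadFile(
    new Uri("https://example.com/some-model.gguf"),   // placeholder URL
    utilities.CreateCellmFilePath("some-model.gguf"));

// Bind the local server to a free ephemeral port.
var baseAddress = new UriBuilder("http", "localhost", utilities.FindPort()).Uri;

// After starting the server process against modelPath and baseAddress:
// await utilities.WaitForServer(new Uri(baseAddress, "health"), process);
```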
@@ -1,7 +1,7 @@
using System.Diagnostics;
using System.Runtime.InteropServices;

public class LLamafileProcessManager
public class ProcessManager
{
[DllImport("kernel32.dll", CharSet = CharSet.Unicode)]
static extern IntPtr CreateJobObject(IntPtr a, string lpName);
@@ -61,7 +61,7 @@ enum JobObjectInfoType

private IntPtr _jobObject;

public LLamafileProcessManager()
public ProcessManager()
{
_jobObject = CreateJobObject(IntPtr.Zero, string.Empty);

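The struct definitions and the assignment path are elided from the diff above. A sketch of the Win32 Job Object technique `ProcessManager` builds on, with struct layouts following the Win32 API: a job created with `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE` kills every assigned process when its handle closes, which happens automatically when the owning process (here, Excel) exits or dies.

```csharp
using System.Diagnostics;
using System.Runtime.InteropServices;

// Illustrative sketch, not the PR's exact code.
static class KillOnCloseJobSketch
{
    const int JobObjectExtendedLimitInformation = 9;
    const uint JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x2000;

    [StructLayout(LayoutKind.Sequential)]
    struct JOBOBJECT_BASIC_LIMIT_INFORMATION
    {
        public long PerProcessUserTimeLimit;
        public long PerJobUserTimeLimit;
        public uint LimitFlags;
        public UIntPtr MinimumWorkingSetSize;
        public UIntPtr MaximumWorkingSetSize;
        public uint ActiveProcessLimit;
        public UIntPtr Affinity;
        public uint PriorityClass;
        public uint SchedulingClass;
    }

    [StructLayout(LayoutKind.Sequential)]
    struct IO_COUNTERS
    {
        public ulong ReadOperationCount, WriteOperationCount, OtherOperationCount;
        public ulong ReadTransferCount, WriteTransferCount, OtherTransferCount;
    }

    [StructLayout(LayoutKind.Sequential)]
    struct JOBOBJECT_EXTENDED_LIMIT_INFORMATION
    {
        public JOBOBJECT_BASIC_LIMIT_INFORMATION BasicLimitInformation;
        public IO_COUNTERS IoInfo;
        public UIntPtr ProcessMemoryLimit;
        public UIntPtr JobMemoryLimit;
        public UIntPtr PeakProcessMemoryUsed;
        public UIntPtr PeakJobMemoryUsed;
    }

    [DllImport("kernel32.dll", CharSet = CharSet.Unicode)]
    static extern IntPtr CreateJobObject(IntPtr lpJobAttributes, string? lpName);

    [DllImport("kernel32.dll")]
    static extern bool SetInformationJobObject(IntPtr hJob, int infoClass,
        ref JOBOBJECT_EXTENDED_LIMIT_INFORMATION lpInfo, uint cbInfo);

    [DllImport("kernel32.dll", SetLastError = true)]
    static extern bool AssignProcessToJobObject(IntPtr hJob, IntPtr hProcess);

    public static void Attach(Process child)
    {
        var job = CreateJobObject(IntPtr.Zero, string.Empty);

        // Configure the job so closing its handle kills all assigned processes.
        var info = new JOBOBJECT_EXTENDED_LIMIT_INFORMATION();
        info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;

        SetInformationJobObject(job, JobObjectExtendedLimitInformation,
            ref info, (uint)Marshal.SizeOf<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>());

        // The child now dies when the OS closes the job handle at process exit.
        AssignProcessToJobObject(job, child.Handle);
    }
}
```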