Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Cellm/AddIn/CellmAddIn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ private static ServiceCollection ConfigureServices(ServiceCollection services)
cfg.AddBehavior<SentryBehavior<ProviderRequest, ProviderResponse>>(ServiceLifetime.Singleton);
cfg.AddBehavior<ToolBehavior<ProviderRequest, ProviderResponse>>(ServiceLifetime.Singleton);
cfg.AddBehavior<CacheBehavior<ProviderRequest, ProviderResponse>>(ServiceLifetime.Singleton);
cfg.AddBehavior<TokenUsageBehavior<ProviderRequest, ProviderResponse>>(ServiceLifetime.Singleton);
})
.AddSingleton<IProviderBehavior, GeminiTemperatureBehavior>();

Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/AddIn/UserInterface/Ribbon/RibbonMain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Cellm.AddIn.UserInterface.Ribbon;
[ComVisible(true)]
public partial class RibbonMain : ExcelRibbon
{
private IRibbonUI? _ribbonUi;
internal static IRibbonUI? _ribbonUi;

private static readonly string _appSettingsPath = Path.Combine(CellmAddIn.ConfigurationPath, "appsettings.json");
private static readonly string _appsettingsLocalPath = Path.Combine(CellmAddIn.ConfigurationPath, "appsettings.Local.json");
Expand Down
52 changes: 50 additions & 2 deletions src/Cellm/AddIn/UserInterface/Ribbon/RibbonModelGroup.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Text;
using Cellm.AddIn.UserInterface.Forms;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Models.Providers.Anthropic;
using Cellm.Models.Providers.Aws;
Expand All @@ -14,6 +13,7 @@
using Cellm.Models.Providers.OpenAiCompatible;
using Cellm.Users;
using ExcelDna.Integration.CustomUI;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -24,7 +24,7 @@ namespace Cellm.AddIn.UserInterface.Ribbon;

public partial class RibbonMain
{
private enum ModelGroupControlIds
public enum ModelGroupControlIds
{
VerticalContainer,
HorizontalContainer,
Expand All @@ -38,6 +38,14 @@ private enum ModelGroupControlIds
ModelComboBox,
TemperatureComboBox,

StatisticsContainer,
StatisticsTokensContainer,
StatisticsSpeedContainer,
TokensLabel,
TokenStatistics,
SpeedLabel,
SpeedStatistics,

CacheToggleButton,

ProviderSettingsButton
Expand Down Expand Up @@ -157,6 +165,14 @@ public string ModelGroup()
getItemCount="{nameof(GetTemperatureItemCount)}"
getItemLabel="{nameof(GetTemperatureItemLabel)}" />
</box>
<box id="{nameof(ModelGroupControlIds.StatisticsTokensContainer)}" boxStyle="horizontal">
<labelControl id="{nameof(ModelGroupControlIds.TokensLabel)}" label="Tokens:" />
<labelControl id="{nameof(ModelGroupControlIds.TokenStatistics)}" getLabel="{nameof(GetTokenStatistics)}" supertip="Total input and output token usage this session" />
</box>
<box id="{nameof(ModelGroupControlIds.StatisticsSpeedContainer)}" boxStyle="horizontal">
<labelControl id="{nameof(ModelGroupControlIds.SpeedLabel)}" label="Speed:" />
<labelControl id="{nameof(ModelGroupControlIds.SpeedStatistics)}" getLabel="{nameof(GetSpeedStatistics)}" supertip="Average Tokens Per Second (TPS) per request and average Requests Per Second" />
</box>
</box>
<separator id="cacheSeparator" />
<toggleButton id="{nameof(ModelGroupControlIds.CacheToggleButton)}"
Expand Down Expand Up @@ -845,4 +861,36 @@ public bool GetCachePressed(IRibbonControl control)
{
return bool.Parse(GetValue($"{nameof(CellmAddInConfiguration)}:{nameof(CellmAddInConfiguration.EnableCache)}"));
}

public string GetTokenStatistics(IRibbonControl control)
{
return $"{FormatCount(TokenUsageNotificationHandler.GetTotalInputTokens())} in / {FormatCount(TokenUsageNotificationHandler.GetTotalOutputTokens())} out";
}

public string GetSpeedStatistics(IRibbonControl control)
{
return $"{TokenUsageNotificationHandler.GetTokensPerSecond():F0} TPS / {TokenUsageNotificationHandler.GetRequestsPerSecond():F1} RPS";
}

public static string FormatCount(long number)
{
if (number == 0) return "0";

string[] suffixes = { "", "K", "M", "B", "T", "P", "E" }; // Kilo, Mega, Giga, Tera, Peta, Exa

// The log base 1000 of the number gives us the magnitude
var magnitude = (int)Math.Log(Math.Abs(number), 1000);

// Don't go beyond the available suffixes
if (magnitude >= suffixes.Length)
{
magnitude = suffixes.Length - 1;
}

// Scale the number down to the 1-999 range
var scaledNumber = number / Math.Pow(1000, magnitude);

// Format the number with one optional decimal place and append the correct suffix
return $"{scaledNumber:0.#}{suffixes[magnitude]}";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using System.Collections.Concurrent;
using Cellm.Models.Behaviors;
using ExcelDna.Integration;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;

namespace Cellm.AddIn.UserInterface.Ribbon;

internal class TokenUsageNotificationHandler(ILogger<TokenUsageNotificationHandler> logger) : INotificationHandler<TokenUsageNotification>
{
private static readonly ConcurrentDictionary<string, long> _tokenUsage = new()
{
[nameof(UsageDetails.InputTokenCount)] = 0,
[nameof(UsageDetails.OutputTokenCount)] = 0
};

private static readonly ConcurrentDictionary<DateTime, (long, double)> _tokensPerSecond = new();
private readonly int _maxTokensPerSecondMeasurements = 100;

public Task Handle(TokenUsageNotification notification, CancellationToken cancellationToken)
{
if (notification is null)
{
logger.LogWarning("Received null usage notification");
return Task.CompletedTask;
}

_tokenUsage[nameof(UsageDetails.InputTokenCount)] += notification.Usage.InputTokenCount ?? 0;
_tokenUsage[nameof(UsageDetails.OutputTokenCount)] += notification.Usage.OutputTokenCount ?? 0;

_tokensPerSecond[DateTime.UtcNow] = (notification.Usage.OutputTokenCount ?? 0, notification.ElapsedTime.TotalSeconds);

// Remove measurements until we are at allowed max
while (_tokensPerSecond.Count > _maxTokensPerSecondMeasurements)
{
var oldestMeasurement = _tokensPerSecond.Keys.Min();
_tokensPerSecond.TryRemove(oldestMeasurement, out _);
}

// Remove measurements older than 30 seconds, as they mess up the statistics
var cutoffTime = DateTime.UtcNow.AddSeconds(-30);
var keysToRemove = _tokensPerSecond.Keys.Where(k => k < cutoffTime).ToList();

foreach (var key in keysToRemove)
{
_tokensPerSecond.TryRemove(key, out _);
}

ExcelAsyncUtil.QueueAsMacro(() =>
{
RibbonMain._ribbonUi?.InvalidateControl(nameof(RibbonMain.ModelGroupControlIds.TokenStatistics));
});

// Update speed statistics iff we have two measurements or more
if (_tokensPerSecond.Count >= 2)
{
ExcelAsyncUtil.QueueAsMacro(() =>
{
RibbonMain._ribbonUi?.InvalidateControl(nameof(RibbonMain.ModelGroupControlIds.SpeedStatistics));
});
}

return Task.CompletedTask;
}

public static long GetTotalInputTokens() => _tokenUsage[nameof(UsageDetails.InputTokenCount)];

public static long GetTotalOutputTokens() => _tokenUsage[nameof(UsageDetails.OutputTokenCount)];

public static double GetTokensPerSecond()
{
if (_tokensPerSecond.IsEmpty)
{
return 0;
}

return _tokensPerSecond.Sum(kvp => kvp.Value.Item1) / (_tokensPerSecond.Average(kvp => kvp.Value.Item2) + 0.01);
}

public static double GetRequestsPerSecond()
{
if (_tokensPerSecond.Count < 2)
{
return 1;
}

var durationInSeconds = _tokensPerSecond.Keys.Max().Subtract(_tokensPerSecond.Keys.Min()).TotalSeconds;

if (durationInSeconds < 0.01)
{
return 1;
}

return _tokensPerSecond.Count / durationInSeconds;
}
}
5 changes: 2 additions & 3 deletions src/Cellm/Models/Behaviors/CacheBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using System.Text;
using System.Text.Json;
using Cellm.AddIn;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.Caching.Hybrid;
Expand All @@ -15,8 +14,8 @@ internal class CacheBehavior<TRequest, TResponse>(
HybridCache cache,
IOptionsMonitor<CellmAddInConfiguration> providerConfiguration,
ILogger<CacheBehavior<TRequest, TResponse>> logger) : IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt
where TResponse : IPrompt
where TRequest : IGetPrompt
where TResponse : IGetPrompt
{
private readonly HybridCacheEntryOptions _cacheEntryOptions = new()
{
Expand Down
8 changes: 8 additions & 0 deletions src/Cellm/Models/Behaviors/IChatResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Microsoft.Extensions.AI;

namespace Cellm.Models.Behaviors;

internal interface IChatResponse
{
public ChatResponse ChatResponse { get; }
}
8 changes: 8 additions & 0 deletions src/Cellm/Models/Behaviors/IGetPrompt.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Cellm.Models.Prompts;

namespace Cellm.Models.Behaviors;

internal interface IGetPrompt
{
public Prompt Prompt { get; }
}
8 changes: 8 additions & 0 deletions src/Cellm/Models/Behaviors/IGetProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using Cellm.Models.Providers;

namespace Cellm.Models.Behaviors;

internal interface IGetProvider
{
public Provider Provider { get; }
}
8 changes: 3 additions & 5 deletions src/Cellm/Models/Behaviors/ProviderBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Models.Providers.Behaviors;
using Cellm.Models.Providers.Behaviors;
using MediatR;

namespace Cellm.Models.Behaviors;

internal class ProviderBehavior<TRequest, TResponse>(IEnumerable<IProviderBehavior> providerBehaviors) : IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt, IProvider
where TResponse : IPrompt, IProvider
where TRequest : IGetPrompt, IGetProvider
where TResponse : IGetPrompt, IGetProvider
{
public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
Expand Down
3 changes: 1 addition & 2 deletions src/Cellm/Models/Behaviors/SentryBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Cellm.AddIn;
using Cellm.Models.Prompts;
using Cellm.Users;
using MediatR;
using Microsoft.Extensions.AI;
Expand All @@ -13,7 +12,7 @@ internal class SentryBehavior<TRequest, TResponse>(
Account account,
ILogger<SentryBehavior<TRequest, TResponse>> logger)
: IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt
where TRequest : IGetPrompt
{
public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
Expand Down
53 changes: 53 additions & 0 deletions src/Cellm/Models/Behaviors/TokenUsageBehavior.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Diagnostics;
using MediatR;
using Microsoft.Extensions.Logging;

namespace Cellm.Models.Behaviors;

internal class TokenUsageBehavior<TRequest, TResponse>(
IPublisher publisher,
ILogger<TokenUsageBehavior<TRequest, TResponse>> logger) : IPipelineBehavior<TRequest, TResponse>
where TRequest : IRequest<TResponse>, IGetProvider
where TResponse : IChatResponse
{
public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
var stopwatch = Stopwatch.StartNew();

// Let the rest of the pipeline (including the actual handler) run
var response = await next().ConfigureAwait(false);

stopwatch.Stop();
var elapsedTime = stopwatch.Elapsed;

if (response.ChatResponse?.Usage is null)
{
logger.LogDebug(
"{provider} completed request in {ElapsedMilliseconds:F2}ms. No token usage details found in response.",
request.Provider,
elapsedTime.TotalMilliseconds
);

return response;
}

var requestType = typeof(TRequest).Name;

logger.LogInformation(
"{provider} completed request in {ElapsedMilliseconds:F2}ms",
requestType,
elapsedTime.TotalMilliseconds
);

var notification = new TokenUsageNotification(
Usage: response.ChatResponse.Usage,
Provider: request.Provider,
Model: response.ChatResponse.ModelId,
ElapsedTime: elapsedTime
);

await publisher.Publish(notification, cancellationToken).ConfigureAwait(false);

return response;
}
}
12 changes: 12 additions & 0 deletions src/Cellm/Models/Behaviors/TokenUsageNotification.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.AI;

namespace Cellm.Models.Behaviors;

public record TokenUsageNotification(
UsageDetails Usage,
Provider Provider,
string? Model,
TimeSpan ElapsedTime
) : INotification;
3 changes: 1 addition & 2 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using Cellm.AddIn;
using Cellm.Models.Prompts;
using Cellm.Tools.ModelContextProtocol;
using Cellm.Users;
using MediatR;
Expand All @@ -20,7 +19,7 @@ internal class ToolBehavior<TRequest, TResponse>(
ILogger<ToolBehavior<TRequest, TResponse>> logger,
ILoggerFactory loggerFactory)
: IPipelineBehavior<TRequest, TResponse>
where TRequest : IPrompt
where TRequest : IGetPrompt
{
// TODO: Cannot use HybridCache because McpClientTool instances can be serialized
private readonly ConcurrentDictionary<string, IList<McpClientTool>> _cache = new();
Expand Down
6 changes: 0 additions & 6 deletions src/Cellm/Models/Prompts/IPrompt.cs

This file was deleted.

6 changes: 0 additions & 6 deletions src/Cellm/Models/Providers/Behaviors/IProvider.cs

This file was deleted.

5 changes: 3 additions & 2 deletions src/Cellm/Models/Providers/ProviderRequest.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Cellm.Models.Prompts;
using Cellm.Models.Behaviors;
using Cellm.Models.Prompts;
using MediatR;

namespace Cellm.Models.Providers;

internal record ProviderRequest(Prompt Prompt, Provider Provider) : IPrompt, IRequest<ProviderResponse>;
internal record ProviderRequest(Prompt Prompt, Provider Provider) : IGetPrompt, IGetProvider, IRequest<ProviderResponse>;
Loading
Loading