Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
Expand All @@ -13,7 +12,6 @@ namespace Microsoft.Extensions.AI.Evaluation.Reporting;

internal sealed class ResponseCachingChatClient : DistributedCachingChatClient
{
private readonly IReadOnlyList<string> _cachingKeys;
private readonly ChatDetails _chatDetails;
private readonly ConcurrentDictionary<string, Stopwatch> _stopWatches;

Expand All @@ -24,7 +22,7 @@ internal ResponseCachingChatClient(
ChatDetails chatDetails)
: base(originalChatClient, cache)
{
_cachingKeys = [.. cachingKeys];
CacheKeyAdditionalValues = [.. cachingKeys];
_chatDetails = chatDetails;
_stopWatches = new ConcurrentDictionary<string, Stopwatch>();
}
Expand Down Expand Up @@ -124,7 +122,4 @@ protected override async Task WriteCacheStreamingAsync(
cacheHit: false));
}
}

protected override string GetCacheKey(IEnumerable<ChatMessage> messages, ChatOptions? options, params ReadOnlySpan<object?> additionalValues)
=> base.GetCacheKey(messages, options, [.. additionalValues, .. _cachingKeys]);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -34,9 +36,16 @@ namespace Microsoft.Extensions.AI;
/// </remarks>
public class DistributedCachingChatClient : CachingChatClient
{
/// <summary>Boxed cache version.</summary>
/// <remarks>Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.</remarks>
private static readonly object _cacheVersion = 2;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;

/// <summary>Additional values used to inform the cache key employed for storing state.</summary>
private object[]? _cacheKeyAdditionalValues;

/// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;

Expand All @@ -56,6 +65,14 @@ public JsonSerializerOptions JsonSerializerOptions
set => _jsonSerializerOptions = Throw.IfNull(value);
}

/// <summary>Gets or sets additional values used to inform the cache key employed for storing state.</summary>
/// <remarks>Any values set in this list will augment the other values used to inform the cache key.</remarks>
public IReadOnlyList<object>? CacheKeyAdditionalValues
{
get => _cacheKeyAdditionalValues;
set => _cacheKeyAdditionalValues = value?.ToArray();
}

/// <inheritdoc />
protected override async Task<ChatResponse?> ReadCacheAsync(string key, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -122,9 +139,26 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
/// </remarks>
protected override string GetCacheKey(IEnumerable<ChatMessage> messages, ChatOptions? options, params ReadOnlySpan<object?> additionalValues)
{
// Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.
const int CacheVersion = 2;
const int FixedValuesCount = 3;

object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty<object>();
int length = FixedValuesCount + additionalValues.Length + clientValues.Length;

return AIJsonUtilities.HashDataToString([CacheVersion, messages, options, .. additionalValues], _jsonSerializerOptions);
object?[] arr = ArrayPool<object?>.Shared.Rent(length);
try
{
arr[0] = _cacheVersion;
arr[1] = messages;
arr[2] = options;
additionalValues.CopyTo(arr.AsSpan(FixedValuesCount));
clientValues.CopyTo(arr, FixedValuesCount + additionalValues.Length);

return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions);
}
finally
{
Array.Clear(arr, 0, length);
ArrayPool<object?>.Shared.Return(arr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}.");
}

if (generated[0] is null)
{
Throw.InvalidOperationException("Generator produced null embedding.");
}

await WriteCacheAsync(cacheKey, generated[0], cancellationToken);
return generated;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
Expand All @@ -24,7 +27,17 @@ namespace Microsoft.Extensions.AI;
public class DistributedCachingEmbeddingGenerator<TInput, TEmbedding> : CachingEmbeddingGenerator<TInput, TEmbedding>
where TEmbedding : Embedding
{
/// <summary>Boxed cache version.</summary>
/// <remarks>Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.</remarks>
private static readonly object _cacheVersion = 2;

/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
private readonly IDistributedCache _storage;

/// <summary>Additional values used to inform the cache key employed for storing state.</summary>
private object[]? _cacheKeyAdditionalValues;

/// <summary>Additional cache key values used to inform the key employed for storing state.</summary>
private JsonSerializerOptions _jsonSerializerOptions;

/// <summary>Initializes a new instance of the <see cref="DistributedCachingEmbeddingGenerator{TInput, TEmbedding}"/> class.</summary>
Expand All @@ -51,6 +64,14 @@ public JsonSerializerOptions JsonSerializerOptions
}
}

/// <summary>Gets or sets additional values used to inform the cache key employed for storing state.</summary>
/// <remarks>Any values set in this list will augment the other values used to inform the cache key.</remarks>
public IReadOnlyList<object>? CacheKeyAdditionalValues
{
get => _cacheKeyAdditionalValues;
set => _cacheKeyAdditionalValues = value?.ToArray();
}

/// <inheritdoc />
protected override async Task<TEmbedding?> ReadCacheAsync(string key, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -87,6 +108,26 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
/// The generated cache key is not guaranteed to be stable across releases of the library.
/// </para>
/// </remarks>
protected override string GetCacheKey(params ReadOnlySpan<object?> values) =>
AIJsonUtilities.HashDataToString(values, _jsonSerializerOptions);
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
{
const int FixedValuesCount = 1;

object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty<object>();
int length = FixedValuesCount + clientValues.Length + values.Length;

object?[] arr = ArrayPool<object?>.Shared.Rent(length);
try
{
arr[0] = _cacheVersion;
values.CopyTo(arr.AsSpan(FixedValuesCount));
clientValues.CopyTo(arr, FixedValuesCount + values.Length);

return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions);
}
finally
{
Array.Clear(arr, 0, length);
ArrayPool<object?>.Shared.Return(arr);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@
}
],
"Properties": [
{
"Member": "System.Collections.Generic.IReadOnlyList<object>? Microsoft.Extensions.AI.DistributedCachingChatClient.CacheKeyAdditionalValues { get; set; }",
"Stage": "Stable"
},
{
"Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingChatClient.JsonSerializerOptions { get; set; }",
"Stage": "Stable"
Expand Down Expand Up @@ -351,6 +355,10 @@
}
],
"Properties": [
{
"Member": "System.Collections.Generic.IReadOnlyList<object>? Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator<TInput, TEmbedding>.CacheKeyAdditionalValues { get; set; }",
"Stage": "Stable"
},
{
"Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator<TInput, TEmbedding>.JsonSerializerOptions { get; set; }",
"Stage": "Stable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,52 @@ public async Task CacheKeyVariesByChatOptionsAsync()
Assert.Equal("value 2", result4.Text);
}

[Fact]
public async Task CacheKeyVariesByAdditionalKeyValuesAsync()
{
// Arrange
var innerCallCount = 0;
var completionTcs = new TaskCompletionSource<bool>();
using var testClient = new TestChatClient
{
GetResponseAsyncCallback = async (_, options, _) =>
{
innerCallCount++;
await Task.Yield();
return new(new ChatMessage(ChatRole.Assistant, innerCallCount.ToString()));
}
};
using var outer = new DistributedCachingChatClient(testClient, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
};

var result1 = await outer.GetResponseAsync([]);
var result2 = await outer.GetResponseAsync([]);

Assert.Equal(1, innerCallCount);
Assert.Equal("1", result1.Text);
Assert.Equal("1", result2.Text);

// Change key
outer.CacheKeyAdditionalValues = ["extraKey"];

var result3 = await outer.GetResponseAsync([]);
var result4 = await outer.GetResponseAsync([]);

Assert.Equal(2, innerCallCount);
Assert.Equal("2", result3.Text);
Assert.Equal("2", result4.Text);

// Remove key
outer.CacheKeyAdditionalValues = [];

var result5 = await outer.GetResponseAsync([]);

Assert.Equal(2, innerCallCount);
Assert.Equal("1", result5.Text);
}

[Fact]
public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ public class DistributedCachingEmbeddingGeneratorTest
AdditionalProperties = new() { ["a"] = "b" },
};

[Fact]
public void Properties_Roundtrip()
{
using var innerGenerator = new TestEmbeddingGenerator();
using DistributedCachingEmbeddingGenerator<string, Embedding<float>> generator = new(innerGenerator, _storage);

Assert.Same(AIJsonUtilities.DefaultOptions, generator.JsonSerializerOptions);
var jso = new JsonSerializerOptions();
generator.JsonSerializerOptions = jso;
Assert.Same(jso, generator.JsonSerializerOptions);

Assert.Null(generator.CacheKeyAdditionalValues);
var additionalValues = new[] { "value1", "value2" };
generator.CacheKeyAdditionalValues = additionalValues;
Assert.NotSame(additionalValues, generator.CacheKeyAdditionalValues);
Assert.Equal(additionalValues, generator.CacheKeyAdditionalValues);
}

[Fact]
public async Task CachesSuccessResultsAsync()
{
Expand Down Expand Up @@ -271,6 +289,49 @@ public async Task CacheKeyVariesByEmbeddingOptionsAsync()
AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
}

[Fact]
public async Task CacheKeyVariesByAdditionalKeyValuesAsync()
{
// Arrange
var innerCallCount = 0;
var completionTcs = new TaskCompletionSource<bool>();
using var innerGenerator = new TestEmbeddingGenerator
{
GenerateAsyncCallback = async (value, options, cancellationToken) =>
{
innerCallCount++;
await Task.Yield();
return new(new Embedding<float>[] { new Embedding<float>(new float[] { innerCallCount }) });
}
};
using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)
{
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
};

var result1 = await outer.GenerateAsync("abc");
var result2 = await outer.GenerateAsync("abc");
AssertEmbeddingsEqual(result1, result2);

var result3 = await outer.GenerateAsync("abc");
AssertEmbeddingsEqual(result1, result3);

// Change key
outer.CacheKeyAdditionalValues = ["extraKey"];

var result4 = await outer.GenerateAsync("abc");
Assert.NotEqual(result1.Vector.ToArray(), result4.Vector.ToArray());

var result5 = await outer.GenerateAsync("abc");
AssertEmbeddingsEqual(result4, result5);

// Remove key
outer.CacheKeyAdditionalValues = [];

var result6 = await outer.GenerateAsync("abc");
AssertEmbeddingsEqual(result1, result6);
}

[Fact]
public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
{
Expand Down
Loading