Skip to content

Commit f78d287

Browse files
authored
Add DistributedCachingChatClient/EmbeddingGenerator.AdditionalCacheKeyValues (#6558)
* Add DistributedCachingChatClient.AdditionalCacheKeyValues GetCacheKey already enabled augmenting the key list, but doing so required deriving a custom caching client. It appears to be a reasonably common thing to want to configure, so exposing it as a property as well. While doing this, I also removed additional per-call allocation coming from GetCacheKey. * Address feedback
1 parent fd8a1f0 commit f78d287

File tree

7 files changed

+201
-11
lines changed

7 files changed

+201
-11
lines changed

src/Libraries/Microsoft.Extensions.AI.Evaluation.Reporting/CSharp/ResponseCachingChatClient.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
54
using System.Collections.Concurrent;
65
using System.Collections.Generic;
76
using System.Diagnostics;
@@ -13,7 +12,6 @@ namespace Microsoft.Extensions.AI.Evaluation.Reporting;
1312

1413
internal sealed class ResponseCachingChatClient : DistributedCachingChatClient
1514
{
16-
private readonly IReadOnlyList<string> _cachingKeys;
1715
private readonly ChatDetails _chatDetails;
1816
private readonly ConcurrentDictionary<string, Stopwatch> _stopWatches;
1917

@@ -24,7 +22,7 @@ internal ResponseCachingChatClient(
2422
ChatDetails chatDetails)
2523
: base(originalChatClient, cache)
2624
{
27-
_cachingKeys = [.. cachingKeys];
25+
CacheKeyAdditionalValues = [.. cachingKeys];
2826
_chatDetails = chatDetails;
2927
_stopWatches = new ConcurrentDictionary<string, Stopwatch>();
3028
}
@@ -124,7 +122,4 @@ protected override async Task WriteCacheStreamingAsync(
124122
cacheHit: false));
125123
}
126124
}
127-
128-
protected override string GetCacheKey(IEnumerable<ChatMessage> messages, ChatOptions? options, params ReadOnlySpan<object?> additionalValues)
129-
=> base.GetCacheKey(messages, options, [.. additionalValues, .. _cachingKeys]);
130125
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Buffers;
56
using System.Collections.Generic;
7+
using System.Linq;
68
using System.Text.Json;
79
using System.Threading;
810
using System.Threading.Tasks;
@@ -34,9 +36,16 @@ namespace Microsoft.Extensions.AI;
3436
/// </remarks>
3537
public class DistributedCachingChatClient : CachingChatClient
3638
{
39+
/// <summary>Boxed cache version.</summary>
40+
/// <remarks>Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.</remarks>
41+
private static readonly object _cacheVersion = 2;
42+
3743
/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
3844
private readonly IDistributedCache _storage;
3945

46+
/// <summary>Additional values used to inform the cache key employed for storing state.</summary>
47+
private object[]? _cacheKeyAdditionalValues;
48+
4049
/// <summary>The <see cref="JsonSerializerOptions"/> to use when serializing cache data.</summary>
4150
private JsonSerializerOptions _jsonSerializerOptions = AIJsonUtilities.DefaultOptions;
4251

@@ -56,6 +65,14 @@ public JsonSerializerOptions JsonSerializerOptions
5665
set => _jsonSerializerOptions = Throw.IfNull(value);
5766
}
5867

68+
/// <summary>Gets or sets additional values used to inform the cache key employed for storing state.</summary>
69+
/// <remarks>Any values set in this list will augment the other values used to inform the cache key.</remarks>
70+
public IReadOnlyList<object>? CacheKeyAdditionalValues
71+
{
72+
get => _cacheKeyAdditionalValues;
73+
set => _cacheKeyAdditionalValues = value?.ToArray();
74+
}
75+
5976
/// <inheritdoc />
6077
protected override async Task<ChatResponse?> ReadCacheAsync(string key, CancellationToken cancellationToken)
6178
{
@@ -122,9 +139,26 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
122139
/// </remarks>
123140
protected override string GetCacheKey(IEnumerable<ChatMessage> messages, ChatOptions? options, params ReadOnlySpan<object?> additionalValues)
124141
{
125-
// Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.
126-
const int CacheVersion = 2;
142+
const int FixedValuesCount = 3;
143+
144+
object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty<object>();
145+
int length = FixedValuesCount + additionalValues.Length + clientValues.Length;
127146

128-
return AIJsonUtilities.HashDataToString([CacheVersion, messages, options, .. additionalValues], _jsonSerializerOptions);
147+
object?[] arr = ArrayPool<object?>.Shared.Rent(length);
148+
try
149+
{
150+
arr[0] = _cacheVersion;
151+
arr[1] = messages;
152+
arr[2] = options;
153+
additionalValues.CopyTo(arr.AsSpan(FixedValuesCount));
154+
clientValues.CopyTo(arr, FixedValuesCount + additionalValues.Length);
155+
156+
return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions);
157+
}
158+
finally
159+
{
160+
Array.Clear(arr, 0, length);
161+
ArrayPool<object?>.Shared.Return(arr);
162+
}
129163
}
130164
}

src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
5454
Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}.");
5555
}
5656

57+
if (generated[0] is null)
58+
{
59+
Throw.InvalidOperationException("Generator produced null embedding.");
60+
}
61+
5762
await WriteCacheAsync(cacheKey, generated[0], cancellationToken);
5863
return generated;
5964
}

src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Buffers;
6+
using System.Collections.Generic;
7+
using System.Linq;
58
using System.Text.Json;
69
using System.Text.Json.Serialization.Metadata;
710
using System.Threading;
@@ -24,7 +27,17 @@ namespace Microsoft.Extensions.AI;
2427
public class DistributedCachingEmbeddingGenerator<TInput, TEmbedding> : CachingEmbeddingGenerator<TInput, TEmbedding>
2528
where TEmbedding : Embedding
2629
{
30+
/// <summary>Boxed cache version.</summary>
31+
/// <remarks>Bump the cache version to invalidate existing caches if the serialization format changes in a breaking way.</remarks>
32+
private static readonly object _cacheVersion = 2;
33+
34+
/// <summary>The <see cref="IDistributedCache"/> instance that will be used as the backing store for the cache.</summary>
2735
private readonly IDistributedCache _storage;
36+
37+
/// <summary>Additional values used to inform the cache key employed for storing state.</summary>
38+
private object[]? _cacheKeyAdditionalValues;
39+
40+
/// <summary>Additional cache key values used to inform the key employed for storing state.</summary>
2841
private JsonSerializerOptions _jsonSerializerOptions;
2942

3043
/// <summary>Initializes a new instance of the <see cref="DistributedCachingEmbeddingGenerator{TInput, TEmbedding}"/> class.</summary>
@@ -51,6 +64,14 @@ public JsonSerializerOptions JsonSerializerOptions
5164
}
5265
}
5366

67+
/// <summary>Gets or sets additional values used to inform the cache key employed for storing state.</summary>
68+
/// <remarks>Any values set in this list will augment the other values used to inform the cache key.</remarks>
69+
public IReadOnlyList<object>? CacheKeyAdditionalValues
70+
{
71+
get => _cacheKeyAdditionalValues;
72+
set => _cacheKeyAdditionalValues = value?.ToArray();
73+
}
74+
5475
/// <inheritdoc />
5576
protected override async Task<TEmbedding?> ReadCacheAsync(string key, CancellationToken cancellationToken)
5677
{
@@ -87,6 +108,26 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc
87108
/// The generated cache key is not guaranteed to be stable across releases of the library.
88109
/// </para>
89110
/// </remarks>
90-
protected override string GetCacheKey(params ReadOnlySpan<object?> values) =>
91-
AIJsonUtilities.HashDataToString(values, _jsonSerializerOptions);
111+
protected override string GetCacheKey(params ReadOnlySpan<object?> values)
112+
{
113+
const int FixedValuesCount = 1;
114+
115+
object[] clientValues = _cacheKeyAdditionalValues ?? Array.Empty<object>();
116+
int length = FixedValuesCount + clientValues.Length + values.Length;
117+
118+
object?[] arr = ArrayPool<object?>.Shared.Rent(length);
119+
try
120+
{
121+
arr[0] = _cacheVersion;
122+
values.CopyTo(arr.AsSpan(FixedValuesCount));
123+
clientValues.CopyTo(arr, FixedValuesCount + values.Length);
124+
125+
return AIJsonUtilities.HashDataToString(arr.AsSpan(0, length), _jsonSerializerOptions);
126+
}
127+
finally
128+
{
129+
Array.Clear(arr, 0, length);
130+
ArrayPool<object?>.Shared.Return(arr);
131+
}
132+
}
92133
}

src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@
310310
}
311311
],
312312
"Properties": [
313+
{
314+
"Member": "System.Collections.Generic.IReadOnlyList<object>? Microsoft.Extensions.AI.DistributedCachingChatClient.CacheKeyAdditionalValues { get; set; }",
315+
"Stage": "Stable"
316+
},
313317
{
314318
"Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingChatClient.JsonSerializerOptions { get; set; }",
315319
"Stage": "Stable"
@@ -351,6 +355,10 @@
351355
}
352356
],
353357
"Properties": [
358+
{
359+
"Member": "System.Collections.Generic.IReadOnlyList<object>? Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator<TInput, TEmbedding>.CacheKeyAdditionalValues { get; set; }",
360+
"Stage": "Stable"
361+
},
354362
{
355363
"Member": "System.Text.Json.JsonSerializerOptions Microsoft.Extensions.AI.DistributedCachingEmbeddingGenerator<TInput, TEmbedding>.JsonSerializerOptions { get; set; }",
356364
"Stage": "Stable"

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,52 @@ public async Task CacheKeyVariesByChatOptionsAsync()
595595
Assert.Equal("value 2", result4.Text);
596596
}
597597

598+
[Fact]
599+
public async Task CacheKeyVariesByAdditionalKeyValuesAsync()
600+
{
601+
// Arrange
602+
var innerCallCount = 0;
603+
var completionTcs = new TaskCompletionSource<bool>();
604+
using var testClient = new TestChatClient
605+
{
606+
GetResponseAsyncCallback = async (_, options, _) =>
607+
{
608+
innerCallCount++;
609+
await Task.Yield();
610+
return new(new ChatMessage(ChatRole.Assistant, innerCallCount.ToString()));
611+
}
612+
};
613+
using var outer = new DistributedCachingChatClient(testClient, _storage)
614+
{
615+
JsonSerializerOptions = TestJsonSerializerContext.Default.Options
616+
};
617+
618+
var result1 = await outer.GetResponseAsync([]);
619+
var result2 = await outer.GetResponseAsync([]);
620+
621+
Assert.Equal(1, innerCallCount);
622+
Assert.Equal("1", result1.Text);
623+
Assert.Equal("1", result2.Text);
624+
625+
// Change key
626+
outer.CacheKeyAdditionalValues = ["extraKey"];
627+
628+
var result3 = await outer.GetResponseAsync([]);
629+
var result4 = await outer.GetResponseAsync([]);
630+
631+
Assert.Equal(2, innerCallCount);
632+
Assert.Equal("2", result3.Text);
633+
Assert.Equal("2", result4.Text);
634+
635+
// Remove key
636+
outer.CacheKeyAdditionalValues = [];
637+
638+
var result5 = await outer.GetResponseAsync([]);
639+
640+
Assert.Equal(2, innerCallCount);
641+
Assert.Equal("1", result5.Text);
642+
}
643+
598644
[Fact]
599645
public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync()
600646
{

test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@ public class DistributedCachingEmbeddingGeneratorTest
2121
AdditionalProperties = new() { ["a"] = "b" },
2222
};
2323

24+
[Fact]
25+
public void Properties_Roundtrip()
26+
{
27+
using var innerGenerator = new TestEmbeddingGenerator();
28+
using DistributedCachingEmbeddingGenerator<string, Embedding<float>> generator = new(innerGenerator, _storage);
29+
30+
Assert.Same(AIJsonUtilities.DefaultOptions, generator.JsonSerializerOptions);
31+
var jso = new JsonSerializerOptions();
32+
generator.JsonSerializerOptions = jso;
33+
Assert.Same(jso, generator.JsonSerializerOptions);
34+
35+
Assert.Null(generator.CacheKeyAdditionalValues);
36+
var additionalValues = new[] { "value1", "value2" };
37+
generator.CacheKeyAdditionalValues = additionalValues;
38+
Assert.NotSame(additionalValues, generator.CacheKeyAdditionalValues);
39+
Assert.Equal(additionalValues, generator.CacheKeyAdditionalValues);
40+
}
41+
2442
[Fact]
2543
public async Task CachesSuccessResultsAsync()
2644
{
@@ -271,6 +289,49 @@ public async Task CacheKeyVariesByEmbeddingOptionsAsync()
271289
AssertEmbeddingsEqual(new("value 2".Select(c => (float)c).ToArray()), result4);
272290
}
273291

292+
[Fact]
293+
public async Task CacheKeyVariesByAdditionalKeyValuesAsync()
294+
{
295+
// Arrange
296+
var innerCallCount = 0;
297+
var completionTcs = new TaskCompletionSource<bool>();
298+
using var innerGenerator = new TestEmbeddingGenerator
299+
{
300+
GenerateAsyncCallback = async (value, options, cancellationToken) =>
301+
{
302+
innerCallCount++;
303+
await Task.Yield();
304+
return new(new Embedding<float>[] { new Embedding<float>(new float[] { innerCallCount }) });
305+
}
306+
};
307+
using var outer = new DistributedCachingEmbeddingGenerator<string, Embedding<float>>(innerGenerator, _storage)
308+
{
309+
JsonSerializerOptions = TestJsonSerializerContext.Default.Options,
310+
};
311+
312+
var result1 = await outer.GenerateAsync("abc");
313+
var result2 = await outer.GenerateAsync("abc");
314+
AssertEmbeddingsEqual(result1, result2);
315+
316+
var result3 = await outer.GenerateAsync("abc");
317+
AssertEmbeddingsEqual(result1, result3);
318+
319+
// Change key
320+
outer.CacheKeyAdditionalValues = ["extraKey"];
321+
322+
var result4 = await outer.GenerateAsync("abc");
323+
Assert.NotEqual(result1.Vector.ToArray(), result4.Vector.ToArray());
324+
325+
var result5 = await outer.GenerateAsync("abc");
326+
AssertEmbeddingsEqual(result4, result5);
327+
328+
// Remove key
329+
outer.CacheKeyAdditionalValues = [];
330+
331+
var result6 = await outer.GenerateAsync("abc");
332+
AssertEmbeddingsEqual(result1, result6);
333+
}
334+
274335
[Fact]
275336
public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync()
276337
{

0 commit comments

Comments
 (0)