Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingChatClient #5616

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingCha…
…tClient

Adds a ToStreamingChatCompletionUpdates method that's the counterpart to the recently added ToChatCompletion.

Then uses both from CachingChatClient instead of its now bespoke coalescing implementation. When coalescing is enabled (the default), CachingChatClient caches everything as a ChatCompletion, rather than distinguishing streaming and non-streaming.
  • Loading branch information
stephentoub committed Nov 11, 2024
commit 65631ef7b51562d76038fef676dd1ac20265cc16
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,53 @@ public ChatMessage Message
/// <inheritdoc />
public override string ToString() =>
Choices is { Count: > 0 } choices ? string.Join(Environment.NewLine, choices) : string.Empty;

/// <summary>Creates an array of <see cref="StreamingChatCompletionUpdate" /> instances that represent this <see cref="ChatCompletion" />.</summary>
/// <returns>An array of <see cref="StreamingChatCompletionUpdate" /> instances that may be used to represent this <see cref="ChatCompletion" />.</returns>
public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates()
{
StreamingChatCompletionUpdate? extra = null;
if (AdditionalProperties is not null || Usage is not null)
{
extra = new StreamingChatCompletionUpdate
{
AdditionalProperties = AdditionalProperties
};

if (Usage is { } usage)
{
extra.Contents.Add(new UsageContent(usage));
}
}

int choicesCount = Choices.Count;
var updates = new StreamingChatCompletionUpdate[choicesCount + Convert.ToInt32(extra is not null)];
stephentoub marked this conversation as resolved.
Show resolved Hide resolved

for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++)
{
ChatMessage choice = Choices[choiceIndex];
updates[choiceIndex] = new StreamingChatCompletionUpdate
{
ChoiceIndex = choiceIndex,

AdditionalProperties = choice.AdditionalProperties,
AuthorName = choice.AuthorName,
Contents = choice.Contents,
RawRepresentation = choice.RawRepresentation,
Role = choice.Role,

CompletionId = CompletionId,
CreatedAt = CreatedAt,
FinishReason = FinishReason,
ModelId = ModelId
};
}

if (extra is not null)
{
updates[choicesCount] = extra;
}

return updates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,35 @@

namespace Microsoft.Extensions.AI;

// Conceptually this combines the roles of ChatCompletion and ChatMessage in streaming output.
// For ease of consumption, it also flattens the nested structure you see on streaming chunks in
// the OpenAI/Gemini APIs, so instead of a dictionary of choices, each update represents a single
// choice (and hence has its own role, choice ID, etc.).

/// <summary>
/// Represents a single response chunk from an <see cref="IChatClient"/>.
/// Represents a single streaming response chunk from an <see cref="IChatClient"/>.
/// </summary>
/// <remarks>
/// <para>
/// Conceptually, this combines the roles of <see cref="ChatCompletion"/> and <see cref="ChatMessage"/>
/// in streaming output. For ease of consumption, it also flattens the nested structure you see on
/// streaming chunks in some AI service, so instead of a dictionary of choices, each update represents a
/// single choice (and hence has its own role, choice ID, etc.).
/// </para>
/// <para>
/// <see cref="StreamingChatCompletionUpdate"/> is so named because it represents streaming updates
/// to a single chat completion. As such, it is considered erroneous for multiple updates that are part
/// of the same completion to contain competing values. For example, some updates that are part of
/// the same completion may have a <see langword="null"/> <see cref="StreamingChatCompletionUpdate.Role"/>
/// value, and others may have a non-<see langword="null"/> value, but all of those with a non-<see langword="null"/>
/// value must have the same value (e.g. <see cref="ChatRole.Assistant"/>. It should never be the case, for example,
/// that one <see cref="StreamingChatCompletionUpdate"/> in a completion has a role of <see cref="ChatRole.Assistant"/>
/// while another has a role of "AI".
/// </para>
/// <para>
/// The relationship between <see cref="ChatCompletion"/> and <see cref="StreamingChatCompletionUpdate"/> is
/// codified in the <see cref="StreamingChatCompletionUpdateExtensions.ToChatCompletionAsync"/> and
/// <see cref="ChatCompletion.ToStreamingChatCompletionUpdates"/>, which enable bidirectional conversions
/// between the two. Note, however, that the conversion may be slightly lossy, for example if multiple updates
/// all have different <see cref="StreamingChatCompletionUpdate.RawRepresentation"/> objects whereas there's
/// only one slot for such an object available in <see cref="ChatCompletion.RawRepresentation"/>.
/// </para>
/// </remarks>
public class StreamingChatCompletionUpdate
{
/// <summary>The completion update content items.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
#if NET
using System.Runtime.InteropServices;
#endif
Expand Down Expand Up @@ -133,7 +134,22 @@ private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictiona
/// <param name="coalesceContent">The corresponding option value provided to <see cref="ToChatCompletion"/> or <see cref="ToChatCompletionAsync"/>.</param>
private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> messages, ChatCompletion completion, bool coalesceContent)
{
foreach (var entry in messages)
if (messages.Count <= 1)
{
foreach (var entry in messages)
{
AddMessage(completion, coalesceContent, entry);
}
}
else
{
foreach (var entry in messages.OrderBy(entry => entry.Key))
{
AddMessage(completion, coalesceContent, entry);
}
}

static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValuePair<int, ChatMessage> entry)
{
if (entry.Value.Role == default)
{
Expand All @@ -154,6 +170,8 @@ private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> message
if (content is UsageContent c)
{
completion.Usage = c.Details;
entry.Value.Contents = entry.Value.Contents.ToList();
_ = entry.Value.Contents.Remove(c);
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -48,13 +47,12 @@ public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chat
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(false, chatMessages, options);

if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing)
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
{
return existing;
result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
}

var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
return result;
}

Expand All @@ -64,127 +62,59 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
{
_ = Throw.IfNull(chatMessages);

var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
if (CoalesceStreamingUpdates)
{
// Yield all of the cached items.
foreach (var chunk in existingChunks)
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
// we make a streaming request, yielding those results, but then convert those into a non-streaming
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.

var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion)
{
yield return chunk;
// Yield all of the cached items.
foreach (var chunk in chatCompletion.ToStreamingChatCompletionUpdates())
{
yield return chunk;
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
capturedItems.Add(chunk);
yield return chunk;
}

// Write the captured items to the cache as a non-streaming result.
await WriteCacheAsync(cacheKey, capturedItems.ToChatCompletion(), cancellationToken).ConfigureAwait(false);
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
capturedItems.Add(chunk);
yield return chunk;
// Yield all of the cached items.
foreach (var chunk in existingChunks)
{
yield return chunk;
}
}

// If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list.
if (CoalesceStreamingUpdates)
else
{
StringBuilder coalescedText = new();

// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++)
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
// If an item isn't generally coalescable, skip it.
StreamingChatCompletionUpdate update = capturedItems[startInclusive];
if (update.ChoiceIndex != 0 ||
update.Contents.Count != 1 ||
update.Contents[0] is not TextContent textContent)
{
continue;
}

// We found a coalescable item. Look for more contiguous items that are also coalescable with it.
int endExclusive = startInclusive + 1;
for (; endExclusive < capturedItems.Count; endExclusive++)
{
StreamingChatCompletionUpdate next = capturedItems[endExclusive];
if (next.ChoiceIndex != 0 ||
next.Contents.Count != 1 ||
next.Contents[0] is not TextContent ||

// changing role or author would be really strange, but check anyway
(update.Role is not null && next.Role is not null && update.Role != next.Role) ||
(update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName))
{
break;
}
}

// If we couldn't find anything to coalesce, there's nothing to do.
if (endExclusive - startInclusive <= 1)
{
continue;
}

// We found a coalescable run of items. Create a new node to represent the run. We create a new one
// rather than reappropriating one of the existing ones so as not to mutate an item already yielded.
_ = coalescedText.Clear().Append(capturedItems[startInclusive].Text);

TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
};

StreamingChatCompletionUpdate coalesced = new()
{
AdditionalProperties = update.AdditionalProperties?.Clone(),
AuthorName = update.AuthorName,
CompletionId = update.CompletionId,
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
ModelId = update.ModelId,
Role = update.Role,

// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
// to represent multiple, and it won't be serialized anyway.
};

// Replace the starting node with the coalesced node.
capturedItems[startInclusive] = coalesced;

// Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties,
// and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation.
// We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of
// the nulls in a single O(N) pass.
for (int i = startInclusive + 1; i < endExclusive; i++)
{
// Grab the next item.
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;

var nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);

coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.ModelId ??= next.ModelId;
coalesced.Role ??= next.Role;
}

// Complete the coalescing by patching the text of the coalesced node.
coalesced.Text = coalescedText.ToString();

// Jump to the last update in the run, so that when we loop around and bump ahead,
// we're at the next update just after the run.
startInclusive = endExclusive - 1;
capturedItems.Add(chunk);
yield return chunk;
}

// Remove all of the null slots left over from the coalescing process.
_ = capturedItems.RemoveAll(u => u is null);
// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}

// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
}

Expand Down
Loading
Loading