Skip to content

Commit 596761a

Browse files
.Net: [fix] StepwisePlanner ChatHistory token calculation refactoring (#2788)
Refactor chat history token calculation and clipping. Leaving the existing `ChatHistory` as-is to ensure a complete transcript is available. Extended class `ChatHistory` to support insert. Fixes #2773 ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent 92c4977 commit 596761a

File tree

4 files changed

+81
-34
lines changed

4 files changed

+81
-34
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System.Linq;
4+
using Microsoft.SemanticKernel.AI.ChatCompletion;
5+
using static Microsoft.SemanticKernel.Text.TextChunker;
6+
7+
namespace Microsoft.SemanticKernel.Planning.Stepwise;
8+
9+
/// <summary>
10+
/// Extension methods for <see cref="ChatHistory"/> class.
11+
/// </summary>
12+
public static class ChatHistoryExtensions
13+
{
14+
/// <summary>
15+
/// Returns the number of tokens in the chat history.
16+
/// </summary>
17+
// <param name="chatHistory">The chat history.</param>
18+
// <param name="additionalMessage">An additional message to include in the token count.</param>
19+
// <param name="skipStart">The index to start skipping messages.</param>
20+
// <param name="skipCount">The number of messages to skip.</param>
21+
// <param name="tokenCounter">The token counter to use.</param>
22+
internal static int GetTokenCount(this ChatHistory chatHistory, string? additionalMessage = null, int skipStart = 0, int skipCount = 0, TokenCounter? tokenCounter = null)
23+
{
24+
tokenCounter ??= DefaultTokenCounter;
25+
26+
var messages = string.Join("\n", chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount).Select(m => m.Content));
27+
28+
if (!string.IsNullOrEmpty(additionalMessage))
29+
{
30+
messages = $"{messages}\n{additionalMessage}";
31+
}
32+
33+
var tokenCount = tokenCounter(messages);
34+
return tokenCount;
35+
}
36+
37+
private static int DefaultTokenCounter(string input)
38+
{
39+
return input.Length / 4;
40+
}
41+
}

dotnet/src/Extensions/Planning.StepwisePlanner/StepwisePlanner.cs

+24-32
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public async Task<SKContext> ExecutePlanAsync(
128128
throw new SKException("ChatHistory is null.");
129129
}
130130

131-
var startingMessageCount = chatHistory.Messages.Count;
131+
var startingMessageCount = chatHistory.Count;
132132

133133
var stepsTaken = new List<SystemStep>();
134134
SystemStep? lastStep = null;
@@ -205,9 +205,9 @@ public async Task<SKContext> ExecutePlanAsync(
205205

206206
lastStep.OriginalResponse += step.OriginalResponse;
207207
step = lastStep;
208-
if (chatHistory.Messages.Count > startingMessageCount)
208+
if (chatHistory.Count > startingMessageCount)
209209
{
210-
chatHistory.Messages.RemoveAt(chatHistory.Messages.Count - 1);
210+
chatHistory.RemoveAt(chatHistory.Count - 1);
211211
}
212212
}
213213
else
@@ -382,34 +382,28 @@ private Task<string> GetSystemMessage(SKContext context)
382382

383383
private Task<string> GetNextStepCompletion(List<SystemStep> stepsTaken, ChatHistory chatHistory, IAIService aiService, int startingMessageCount, CancellationToken token)
384384
{
385-
var tokenCount = this.GetChatHistoryTokens(chatHistory);
385+
var skipStart = startingMessageCount;
386+
var skipCount = 0;
386387

387-
var preserveFirstNSteps = 0;
388-
var removalIndex = (startingMessageCount) + preserveFirstNSteps;
389-
var messagesRemoved = 0;
390388
string? originalThought = null;
391-
while (tokenCount >= this.Config.MaxTokens && chatHistory.Messages.Count > removalIndex)
392-
{
393-
// something needs to be removed.
394-
if (string.IsNullOrEmpty(originalThought))
395-
{
396-
originalThought = stepsTaken[0].Thought;
397-
}
398389

399-
// Update message history
400-
chatHistory.AddAssistantMessage($"{Thought} {originalThought}");
401-
preserveFirstNSteps++;
402-
chatHistory.AddAssistantMessage("... I've removed some of my previous work to make room for the new stuff ...");
403-
preserveFirstNSteps++;
390+
var tokenCount = chatHistory.GetTokenCount();
391+
while (tokenCount >= this.Config.MaxTokens && chatHistory.Count > skipStart)
392+
{
393+
originalThought = $"{Thought} {stepsTaken.FirstOrDefault()?.Thought}";
394+
tokenCount = chatHistory.GetTokenCount($"{originalThought}\n{TrimMessage}", skipStart, ++skipCount);
395+
}
404396

405-
removalIndex = (startingMessageCount) + preserveFirstNSteps;
397+
var reducedChatHistory = new ChatHistory();
398+
reducedChatHistory.AddRange(chatHistory.Where((m, i) => i < skipStart || i >= skipStart + skipCount));
406399

407-
chatHistory.Messages.RemoveAt(removalIndex);
408-
tokenCount = this.GetChatHistoryTokens(chatHistory);
409-
messagesRemoved++;
400+
if (skipCount > 0 && originalThought is not null)
401+
{
402+
reducedChatHistory.InsertMessage(skipStart, AuthorRole.Assistant, TrimMessage);
403+
reducedChatHistory.InsertMessage(skipStart, AuthorRole.Assistant, originalThought);
410404
}
411405

412-
return this.GetCompletionAsync(aiService, chatHistory, stepsTaken.Count == 0, token);
406+
return this.GetCompletionAsync(aiService, reducedChatHistory, stepsTaken.Count == 0, token);
413407
}
414408

415409
private async Task<string> GetCompletionAsync(IAIService aiService, ChatHistory chatHistory, bool addThought, CancellationToken token)
@@ -421,7 +415,7 @@ private async Task<string> GetCompletionAsync(IAIService aiService, ChatHistory
421415
}
422416
else if (aiService is ITextCompletion textCompletion)
423417
{
424-
var thoughtProcess = string.Join("\n", chatHistory.Messages.Select(m => m.Content));
418+
var thoughtProcess = string.Join("\n", chatHistory.Select(m => m.Content));
425419

426420
// Add Thought to the thought process at the start of the first iteration
427421
if (addThought)
@@ -444,13 +438,6 @@ private async Task<string> GetCompletionAsync(IAIService aiService, ChatHistory
444438
throw new SKException("No AIService available for getting completions.");
445439
}
446440

447-
private int GetChatHistoryTokens(ChatHistory chatHistory)
448-
{
449-
var messages = string.Join("\n", chatHistory.Messages);
450-
var tokenCount = messages.Length / 4;
451-
return tokenCount;
452-
}
453-
454441
/// <summary>
455442
/// Parse LLM response into a SystemStep during execution
456443
/// </summary>
@@ -749,6 +736,11 @@ private static string ToFullyQualifiedName(FunctionView function)
749736
/// </summary>
750737
private const string Observation = "[OBSERVATION]";
751738

739+
/// <summary>
740+
/// The chat message to include when trimming thought process history
741+
/// </summary>
742+
private const string TrimMessage = "... I've removed some of my previous work to make room for the new stuff ...";
743+
752744
/// <summary>
753745
/// The regex for parsing the thought response
754746
/// </summary>

dotnet/src/Extensions/Planning.StepwisePlanner/StepwisePlannerConfig.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@ public sealed class StepwisePlannerConfig
5151
#region Execution configuration
5252

5353
/// <summary>
54-
/// The maximum number of tokens to allow in a plan.
54+
/// The maximum number of tokens to allow in a request and for completion.
5555
/// </summary>
56-
public int MaxTokens { get; set; } = 1024;
56+
/// <remarks>
57+
/// Default value is 2000.
58+
/// </remarks>
59+
public int MaxTokens { get; set; } = 2000;
5760

5861
/// <summary>
5962
/// The maximum number of iterations to allow in a plan.

dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs

+11
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ public void AddMessage(AuthorRole authorRole, string content)
3333
this.Add(new ChatMessage(authorRole, content));
3434
}
3535

36+
/// <summary>
37+
/// Insert a message into the chat history
38+
/// </summary>
39+
/// <param name="index">Index of the message to insert</param>
40+
/// <param name="authorRole">Role of the message author</param>
41+
/// <param name="content">Message content</param>
42+
public void InsertMessage(int index, AuthorRole authorRole, string content)
43+
{
44+
this.Insert(index, new ChatMessage(authorRole, content));
45+
}
46+
3647
/// <summary>
3748
/// Add a user message to the chat history
3849
/// </summary>

0 commit comments

Comments
 (0)