Skip to content

Commit 8151fbe

Browse files
MackinnonBuck authored and jeffhandley committed
Preserve function content in SummarizingChatReducer (dotnet#6908)
1 parent b819cfd commit 8151fbe

File tree

2 files changed

+229
-34
lines changed

2 files changed

+229
-34
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatReduction/SummarizingChatReducer.cs

Lines changed: 84 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5+
using System.Diagnostics;
56
using System.Diagnostics.CodeAnalysis;
67
using System.Linq;
78
using System.Threading;
@@ -73,18 +74,24 @@ public async Task<IEnumerable<ChatMessage>> ReduceAsync(IEnumerable<ChatMessage>
7374
{
7475
_ = Throw.IfNull(messages);
7576

76-
var summarizedConversion = SummarizedConversation.FromChatMessages(messages);
77-
if (summarizedConversion.ShouldResummarize(_targetCount, _thresholdCount))
77+
var summarizedConversation = SummarizedConversation.FromChatMessages(messages);
78+
var indexOfFirstMessageToKeep = summarizedConversation.FindIndexOfFirstMessageToKeep(_targetCount, _thresholdCount);
79+
if (indexOfFirstMessageToKeep > 0)
7880
{
79-
summarizedConversion = await summarizedConversion.ResummarizeAsync(
80-
_chatClient, _targetCount, SummarizationPrompt, cancellationToken);
81+
summarizedConversation = await summarizedConversation.ResummarizeAsync(
82+
_chatClient,
83+
indexOfFirstMessageToKeep,
84+
SummarizationPrompt,
85+
cancellationToken);
8186
}
8287

83-
return summarizedConversion.ToChatMessages();
88+
return summarizedConversation.ToChatMessages();
8489
}
8590

91+
/// <summary>Represents a conversation with an optional summary.</summary>
8692
private readonly struct SummarizedConversation(string? summary, ChatMessage? systemMessage, IList<ChatMessage> unsummarizedMessages)
8793
{
94+
/// <summary>Creates a <see cref="SummarizedConversation"/> from a list of chat messages.</summary>
8895
public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> messages)
8996
{
9097
string? summary = null;
@@ -102,7 +109,7 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
102109
unsummarizedMessages.Clear();
103110
summary = summaryValue;
104111
}
105-
else if (!message.Contents.Any(m => m is FunctionCallContent or FunctionResultContent))
112+
else
106113
{
107114
unsummarizedMessages.Add(message);
108115
}
@@ -111,31 +118,68 @@ public static SummarizedConversation FromChatMessages(IEnumerable<ChatMessage> m
111118
return new(summary, systemMessage, unsummarizedMessages);
112119
}
113120

114-
public bool ShouldResummarize(int targetCount, int thresholdCount)
115-
=> unsummarizedMessages.Count > targetCount + thresholdCount;
116-
117-
public async Task<SummarizedConversation> ResummarizeAsync(
118-
IChatClient chatClient, int targetCount, string summarizationPrompt, CancellationToken cancellationToken)
121+
/// <summary>Performs summarization by calling the chat client and updating the conversation state.</summary>
122+
public async ValueTask<SummarizedConversation> ResummarizeAsync(
123+
IChatClient chatClient, int indexOfFirstMessageToKeep, string summarizationPrompt, CancellationToken cancellationToken)
119124
{
120-
var messagesToResummarize = unsummarizedMessages.Count - targetCount;
121-
if (messagesToResummarize <= 0)
122-
{
123-
// We're at or below the target count - no need to resummarize.
124-
return this;
125-
}
125+
Debug.Assert(indexOfFirstMessageToKeep > 0, "Expected positive index for first message to keep.");
126126

127-
var summarizerChatMessages = ToSummarizerChatMessages(messagesToResummarize, summarizationPrompt);
127+
// Generate the summary by sending unsummarized messages to the chat client
128+
var summarizerChatMessages = ToSummarizerChatMessages(indexOfFirstMessageToKeep, summarizationPrompt);
128129
var response = await chatClient.GetResponseAsync(summarizerChatMessages, cancellationToken: cancellationToken);
129130
var newSummary = response.Text;
130131

131-
var lastSummarizedMessage = unsummarizedMessages[messagesToResummarize - 1];
132+
// Attach the summary metadata to the last message being summarized
133+
// This is what allows us to build on previously-generated summaries
134+
var lastSummarizedMessage = unsummarizedMessages[indexOfFirstMessageToKeep - 1];
132135
var additionalProperties = lastSummarizedMessage.AdditionalProperties ??= [];
133136
additionalProperties[SummaryKey] = newSummary;
134137

135-
var newUnsummarizedMessages = unsummarizedMessages.Skip(messagesToResummarize).ToList();
138+
// Compute the new list of unsummarized messages
139+
var newUnsummarizedMessages = unsummarizedMessages.Skip(indexOfFirstMessageToKeep).ToList();
136140
return new SummarizedConversation(newSummary, systemMessage, newUnsummarizedMessages);
137141
}
138142

143+
/// <summary>Determines the index of the first message to keep (not summarize) based on target and threshold counts.</summary>
144+
public int FindIndexOfFirstMessageToKeep(int targetCount, int thresholdCount)
145+
{
146+
var earliestAllowedIndex = unsummarizedMessages.Count - thresholdCount - targetCount;
147+
if (earliestAllowedIndex <= 0)
148+
{
149+
// Not enough messages to warrant summarization
150+
return 0;
151+
}
152+
153+
// Start at the ideal cut point (keeping exactly targetCount messages)
154+
var indexOfFirstMessageToKeep = unsummarizedMessages.Count - targetCount;
155+
156+
// Move backward to skip over function call/result content at the boundary
157+
// We want to keep complete function call sequences together with their responses
158+
while (indexOfFirstMessageToKeep > 0)
159+
{
160+
if (!unsummarizedMessages[indexOfFirstMessageToKeep - 1].Contents.Any(IsToolRelatedContent))
161+
{
162+
break;
163+
}
164+
165+
indexOfFirstMessageToKeep--;
166+
}
167+
168+
// Search backward within the threshold window to find a User message
169+
// If found, cut right before it to avoid orphaning user questions from responses
170+
for (var i = indexOfFirstMessageToKeep; i >= earliestAllowedIndex; i--)
171+
{
172+
if (unsummarizedMessages[i].Role == ChatRole.User)
173+
{
174+
return i;
175+
}
176+
}
177+
178+
// No User message found within threshold - use the adjusted cut point
179+
return indexOfFirstMessageToKeep;
180+
}
181+
182+
/// <summary>Converts the summarized conversation back into a collection of chat messages.</summary>
139183
public IEnumerable<ChatMessage> ToChatMessages()
140184
{
141185
if (systemMessage is not null)
@@ -154,16 +198,33 @@ public IEnumerable<ChatMessage> ToChatMessages()
154198
}
155199
}
156200

157-
private IEnumerable<ChatMessage> ToSummarizerChatMessages(int messagesToResummarize, string summarizationPrompt)
201+
/// <summary>Returns whether the given <see cref="AIContent"/> relates to tool calling capabilities.</summary>
202+
/// <remarks>
203+
/// This method returns <see langword="true"/> for content types whose meaning depends on other related <see cref="AIContent"/>
204+
/// instances in the conversation, such as function calls that require corresponding results, or other tool interactions that span
205+
/// multiple messages. Such content should be kept together during summarization.
206+
/// </remarks>
207+
private static bool IsToolRelatedContent(AIContent content) => content
208+
is FunctionCallContent
209+
or FunctionResultContent
210+
or UserInputRequestContent
211+
or UserInputResponseContent;
212+
213+
/// <summary>Builds the list of messages to send to the chat client for summarization.</summary>
214+
private IEnumerable<ChatMessage> ToSummarizerChatMessages(int indexOfFirstMessageToKeep, string summarizationPrompt)
158215
{
159216
if (summary is not null)
160217
{
161218
yield return new ChatMessage(ChatRole.Assistant, summary);
162219
}
163220

164-
for (var i = 0; i < messagesToResummarize; i++)
221+
for (var i = 0; i < indexOfFirstMessageToKeep; i++)
165222
{
166-
yield return unsummarizedMessages[i];
223+
var message = unsummarizedMessages[i];
224+
if (!message.Contents.Any(IsToolRelatedContent))
225+
{
226+
yield return message;
227+
}
167228
}
168229

169230
yield return new ChatMessage(ChatRole.System, summarizationPrompt);

test/Libraries/Microsoft.Extensions.AI.Tests/ChatReduction/SummarizingChatReducerTests.cs

Lines changed: 145 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -84,27 +84,145 @@ public async Task ReduceAsync_PreservesSystemMessage()
8484
}
8585

8686
[Fact]
87-
public async Task ReduceAsync_IgnoresFunctionCallsAndResults()
87+
public async Task ReduceAsync_PreservesCompleteToolCallSequence()
8888
{
8989
using var chatClient = new TestChatClient();
90-
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0);
90+
91+
// Target 2 messages, but this would split a function call sequence
92+
var reducer = new SummarizingChatReducer(chatClient, targetCount: 2, threshold: 0);
9193

9294
List<ChatMessage> messages =
9395
[
96+
new ChatMessage(ChatRole.User, "What's the time?"),
97+
new ChatMessage(ChatRole.Assistant, "Let me check"),
9498
new ChatMessage(ChatRole.User, "What's the weather?"),
95-
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary<string, object?> { ["location"] = "Seattle" })]),
96-
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]),
97-
new ChatMessage(ChatRole.Assistant, "The weather in Seattle is sunny and 72°F."),
98-
new ChatMessage(ChatRole.User, "Thanks!"),
99+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather"), new TestUserInputRequestContent("uir1")]),
100+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny")]),
101+
new ChatMessage(ChatRole.User, [new TestUserInputResponseContent("uir1")]),
102+
new ChatMessage(ChatRole.Assistant, "It's sunny"),
99103
];
100104

105+
chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
106+
{
107+
Assert.DoesNotContain(msgs, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent or TestUserInputRequestContent or TestUserInputResponseContent));
108+
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Asked about time")));
109+
};
110+
101111
var result = await reducer.ReduceAsync(messages, CancellationToken.None);
112+
var resultList = result.ToList();
102113

103-
// Function calls/results should be ignored, which means there aren't enough messages to generate a summary.
114+
// Should have: summary + function call + function result + user input response + last reply
115+
Assert.Equal(5, resultList.Count);
116+
117+
// Verify the complete sequence is preserved
118+
Assert.Collection(resultList,
119+
m => Assert.Contains("Asked about time", m.Text),
120+
m =>
121+
{
122+
Assert.Contains(m.Contents, c => c is FunctionCallContent);
123+
Assert.Contains(m.Contents, c => c is TestUserInputRequestContent);
124+
},
125+
m => Assert.Contains(m.Contents, c => c is FunctionResultContent),
126+
m => Assert.Contains(m.Contents, c => c is TestUserInputResponseContent),
127+
m => Assert.Contains("sunny", m.Text));
128+
}
129+
130+
[Fact]
131+
public async Task ReduceAsync_PreservesUserMessageWhenWithinThreshold()
132+
{
133+
using var chatClient = new TestChatClient();
134+
135+
// Target 3 messages with threshold of 2
136+
// This allows us to keep anywhere from 3 to 5 messages
137+
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 2);
138+
139+
List<ChatMessage> messages =
140+
[
141+
new ChatMessage(ChatRole.User, "First question"),
142+
new ChatMessage(ChatRole.Assistant, "First answer"),
143+
new ChatMessage(ChatRole.User, "Second question"),
144+
new ChatMessage(ChatRole.Assistant, "Second answer"),
145+
new ChatMessage(ChatRole.User, "Third question"),
146+
new ChatMessage(ChatRole.Assistant, "Third answer"),
147+
];
148+
149+
chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
150+
{
151+
var msgList = msgs.ToList();
152+
153+
// Should summarize messages 0-1 (First question and answer)
154+
// The reducer should find the User message at index 2 within the threshold
155+
Assert.Equal(3, msgList.Count); // 2 messages to summarize + system prompt
156+
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "Summary of first exchange")));
157+
};
158+
159+
var result = await reducer.ReduceAsync(messages, CancellationToken.None);
104160
var resultList = result.ToList();
105-
Assert.Equal(3, resultList.Count); // Function calls get removed in the summarized chat.
106-
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent));
107-
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent));
161+
162+
// Should have: summary + 4 kept messages (from "Second question" onward)
163+
Assert.Equal(5, resultList.Count);
164+
165+
// Verify the summary is first
166+
Assert.Contains("Summary", resultList[0].Text);
167+
168+
// Verify we kept the User message at index 2 and everything after
169+
Assert.Collection(resultList.Skip(1),
170+
m => Assert.Contains("Second question", m.Text),
171+
m => Assert.Contains("Second answer", m.Text),
172+
m => Assert.Contains("Third question", m.Text),
173+
m => Assert.Contains("Third answer", m.Text));
174+
}
175+
176+
[Fact]
177+
public async Task ReduceAsync_ExcludesToolCallsFromSummarizedPortion()
178+
{
179+
using var chatClient = new TestChatClient();
180+
181+
// Target 3 messages - this will cause function calls in older messages to be summarized (excluded)
182+
// while function calls in recent messages are kept
183+
var reducer = new SummarizingChatReducer(chatClient, targetCount: 3, threshold: 0);
184+
185+
List<ChatMessage> messages =
186+
[
187+
new ChatMessage(ChatRole.User, "What's the weather in Seattle?"),
188+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call1", "get_weather", new Dictionary<string, object?> { ["location"] = "Seattle" }), new TestUserInputRequestContent("uir2")]),
189+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call1", "Sunny, 72°F")]),
190+
new ChatMessage(ChatRole.User, [new TestUserInputResponseContent("uir2")]),
191+
new ChatMessage(ChatRole.Assistant, "It's sunny and 72°F in Seattle."),
192+
new ChatMessage(ChatRole.User, "What about New York?"),
193+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call2", "get_weather", new Dictionary<string, object?> { ["location"] = "New York" })]),
194+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("call2", "Rainy, 65°F")]),
195+
new ChatMessage(ChatRole.Assistant, "It's rainy and 65°F in New York."),
196+
];
197+
198+
chatClient.GetResponseAsyncCallback = (msgs, _, _) =>
199+
{
200+
var msgList = msgs.ToList();
201+
202+
Assert.Equal(4, msgList.Count); // 3 non-function messages + system prompt
203+
Assert.DoesNotContain(msgList, m => m.Contents.Any(c => c is FunctionCallContent or FunctionResultContent or TestUserInputRequestContent or TestUserInputResponseContent));
204+
Assert.Contains(msgList, m => m.Text.Contains("What's the weather in Seattle?"));
205+
Assert.Contains(msgList, m => m.Text.Contains("sunny and 72°F in Seattle"));
206+
Assert.Contains(msgList, m => m.Text.Contains("What about New York?"));
207+
Assert.Contains(msgList, m => m.Role == ChatRole.System);
208+
209+
return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "User asked about weather in Seattle and New York.")));
210+
};
211+
212+
var result = await reducer.ReduceAsync(messages, CancellationToken.None);
213+
var resultList = result.ToList();
214+
215+
// Should have: summary + 3 kept messages (the last 3 messages with function calls)
216+
Assert.Equal(4, resultList.Count);
217+
218+
Assert.Contains("User asked about weather", resultList[0].Text);
219+
Assert.Contains(resultList, m => m.Contents.Any(c => c is FunctionCallContent fc && fc.CallId == "call2"));
220+
Assert.Contains(resultList, m => m.Contents.Any(c => c is FunctionResultContent fr && fr.CallId == "call2"));
221+
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionCallContent fc && fc.CallId == "call1"));
222+
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is FunctionResultContent fr && fr.CallId == "call1"));
223+
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is TestUserInputRequestContent));
224+
Assert.DoesNotContain(resultList, m => m.Contents.Any(c => c is TestUserInputResponseContent));
225+
Assert.DoesNotContain(resultList, m => m.Text.Contains("sunny and 72°F in Seattle"));
108226
}
109227

110228
[Theory]
@@ -121,7 +239,7 @@ public async Task ReduceAsync_RespectsTargetAndThresholdCounts(int targetCount,
121239
var messages = new List<ChatMessage>();
122240
for (int i = 0; i < messageCount; i++)
123241
{
124-
messages.Add(new ChatMessage(i % 2 == 0 ? ChatRole.User : ChatRole.Assistant, $"Message {i}"));
242+
messages.Add(new ChatMessage(ChatRole.Assistant, $"Message {i}"));
125243
}
126244

127245
var summarizationCalled = false;
@@ -266,4 +384,20 @@ need frequent exercise. The user then asked about whether they're good around ki
266384
m => Assert.StartsWith("Golden retrievers get along", m.Text, StringComparison.Ordinal),
267385
m => Assert.StartsWith("Do they make good lap dogs", m.Text, StringComparison.Ordinal));
268386
}
387+
388+
private sealed class TestUserInputRequestContent : UserInputRequestContent
389+
{
390+
public TestUserInputRequestContent(string id)
391+
: base(id)
392+
{
393+
}
394+
}
395+
396+
private sealed class TestUserInputResponseContent : UserInputResponseContent
397+
{
398+
public TestUserInputResponseContent(string id)
399+
: base(id)
400+
{
401+
}
402+
}
269403
}

0 commit comments

Comments
 (0)