Skip to content

Commit 7cac12b

Browse files
authored
Fix a few issues in IChatClient implementations (#5549)
* Fix a few issues in IChatClient implementations
  - Avoid null arg exception when constructing system message with null text
  - Avoid empty exception when constructing user message with no parts
  - Use all parts rather than just first text part for system message
  - Handle assistant messages with both content and tools
  - Avoid unnecessarily trying to weed out duplicate call ids
* Address PR feedback
  - Normalize null to string.Empty in TextContent
  - Ensure GetContentParts always produces at least one part, even if empty text content
1 parent 2dd959f commit 7cac12b

File tree

7 files changed

+510
-64
lines changed

7 files changed

+510
-64
lines changed
Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,36 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Diagnostics.CodeAnalysis;
5+
46
namespace Microsoft.Extensions.AI;
57

68
/// <summary>
79
/// Represents text content in a chat.
810
/// </summary>
911
public sealed class TextContent : AIContent
1012
{
13+
private string? _text;
14+
1115
/// <summary>
1216
/// Initializes a new instance of the <see cref="TextContent"/> class.
1317
/// </summary>
1418
/// <param name="text">The text content.</param>
1519
public TextContent(string? text)
1620
{
17-
Text = text;
21+
_text = text;
1822
}
1923

2024
/// <summary>
2125
/// Gets or sets the text content.
2226
/// </summary>
23-
public string? Text { get; set; }
27+
[AllowNull]
28+
public string Text
29+
{
30+
get => _text ?? string.Empty;
31+
set => _text = value;
32+
}
2433

2534
/// <inheritdoc/>
26-
public override string ToString() => Text ?? string.Empty;
35+
public override string ToString() => Text;
2736
}

src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
using System;
55
using System.Collections.Generic;
6-
using System.Linq;
76
using System.Reflection;
87
using System.Runtime.CompilerServices;
98
using System.Text;
@@ -410,13 +409,13 @@ private sealed class AzureAIChatToolJson
410409
private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerable<ChatMessage> inputs)
411410
{
412411
// Maps all of the M.E.AI types to the corresponding AzureAI types.
413-
// Unrecognized content is ignored.
412+
// Unrecognized or non-processable content is ignored.
414413

415414
foreach (ChatMessage input in inputs)
416415
{
417416
if (input.Role == ChatRole.System)
418417
{
419-
yield return new ChatRequestSystemMessage(input.Text);
418+
yield return new ChatRequestSystemMessage(input.Text ?? string.Empty);
420419
}
421420
else if (input.Role == ChatRole.Tool)
422421
{
@@ -444,52 +443,64 @@ private IEnumerable<ChatRequestMessage> ToAzureAIInferenceChatMessages(IEnumerab
444443
}
445444
else if (input.Role == ChatRole.User)
446445
{
447-
yield return new ChatRequestUserMessage(input.Contents.Select(static (AIContent item) => item switch
448-
{
449-
TextContent textContent => new ChatMessageTextContentItem(textContent.Text),
450-
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType) :
451-
imageContent.Uri is string uri ? new ChatMessageImageContentItem(new Uri(uri)) :
452-
(ChatMessageContentItem?)null,
453-
_ => null,
454-
}).Where(c => c is not null));
446+
yield return new ChatRequestUserMessage(GetContentParts(input.Contents));
455447
}
456448
else if (input.Role == ChatRole.Assistant)
457449
{
458-
Dictionary<string, ChatCompletionsToolCall>? toolCalls = null;
450+
// TODO: ChatRequestAssistantMessage only enables text content currently.
451+
// Update it with other content types when it supports that.
452+
ChatRequestAssistantMessage message = new()
453+
{
454+
Content = input.Text
455+
};
459456

460457
foreach (var content in input.Contents)
461458
{
462-
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
459+
if (content is FunctionCallContent { CallId: not null } callRequest)
463460
{
464461
JsonSerializerOptions serializerOptions = ToolCallJsonSerializerOptions ?? JsonContext.Default.Options;
465-
string jsonArguments = JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)));
466-
(toolCalls ??= []).Add(
462+
message.ToolCalls.Add(new ChatCompletionsFunctionToolCall(
467463
callRequest.CallId,
468-
new ChatCompletionsFunctionToolCall(
469-
callRequest.CallId,
470-
callRequest.Name,
471-
jsonArguments));
464+
callRequest.Name,
465+
JsonSerializer.Serialize(callRequest.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary<string, object>)))));
472466
}
473467
}
474468

475-
ChatRequestAssistantMessage message = new();
476-
if (toolCalls is not null)
477-
{
478-
foreach (var entry in toolCalls)
479-
{
480-
message.ToolCalls.Add(entry.Value);
481-
}
482-
}
483-
else
484-
{
485-
message.Content = input.Text;
486-
}
487-
488469
yield return message;
489470
}
490471
}
491472
}
492473

474+
/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentItem"/>.</summary>
475+
private static List<ChatMessageContentItem> GetContentParts(IList<AIContent> contents)
476+
{
477+
List<ChatMessageContentItem> parts = [];
478+
foreach (var content in contents)
479+
{
480+
switch (content)
481+
{
482+
case TextContent textContent:
483+
(parts ??= []).Add(new ChatMessageTextContentItem(textContent.Text));
484+
break;
485+
486+
case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
487+
(parts ??= []).Add(new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType));
488+
break;
489+
490+
case ImageContent imageContent when imageContent.Uri is string uri:
491+
(parts ??= []).Add(new ChatMessageImageContentItem(new Uri(uri)));
492+
break;
493+
}
494+
}
495+
496+
if (parts.Count == 0)
497+
{
498+
parts.Add(new ChatMessageTextContentItem(string.Empty));
499+
}
500+
501+
return parts;
502+
}
503+
493504
private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
494505
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
495506
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
using System;
55
using System.Collections.Generic;
6-
using System.Linq;
76
using System.Reflection;
87
using System.Runtime.CompilerServices;
98
using System.Text;
@@ -569,13 +568,16 @@ private sealed class OpenAIChatToolJson
569568
private IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs)
570569
{
571570
// Maps all of the M.E.AI types to the corresponding OpenAI types.
572-
// Unrecognized content is ignored.
571+
// Unrecognized or non-processable content is ignored.
573572

574573
foreach (ChatMessage input in inputs)
575574
{
576-
if (input.Role == ChatRole.System)
575+
if (input.Role == ChatRole.System || input.Role == ChatRole.User)
577576
{
578-
yield return new SystemChatMessage(input.Text) { ParticipantName = input.AuthorName };
577+
var parts = GetContentParts(input.Contents);
578+
yield return input.Role == ChatRole.System ?
579+
new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
580+
new UserChatMessage(parts) { ParticipantName = input.AuthorName };
579581
}
580582
else if (input.Role == ChatRole.Tool)
581583
{
@@ -601,39 +603,25 @@ private sealed class OpenAIChatToolJson
601603
}
602604
}
603605
}
604-
else if (input.Role == ChatRole.User)
605-
{
606-
yield return new UserChatMessage(input.Contents.Select(static (AIContent item) => item switch
607-
{
608-
TextContent textContent => ChatMessageContentPart.CreateTextPart(textContent.Text),
609-
ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType) :
610-
imageContent.Uri is string uri ? ChatMessageContentPart.CreateImagePart(new Uri(uri)) :
611-
null,
612-
_ => null,
613-
}).Where(c => c is not null))
614-
{ ParticipantName = input.AuthorName };
615-
}
616606
else if (input.Role == ChatRole.Assistant)
617607
{
618-
Dictionary<string, ChatToolCall>? toolCalls = null;
608+
AssistantChatMessage message = new(GetContentParts(input.Contents))
609+
{
610+
ParticipantName = input.AuthorName
611+
};
619612

620613
foreach (var content in input.Contents)
621614
{
622-
if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true)
615+
if (content is FunctionCallContent { CallId: not null } callRequest)
623616
{
624-
(toolCalls ??= []).Add(
625-
callRequest.CallId,
617+
message.ToolCalls.Add(
626618
ChatToolCall.CreateFunctionToolCall(
627619
callRequest.CallId,
628620
callRequest.Name,
629621
BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions)));
630622
}
631623
}
632624

633-
AssistantChatMessage message = toolCalls is not null ?
634-
new(toolCalls.Values) { ParticipantName = input.AuthorName } :
635-
new(input.Text) { ParticipantName = input.AuthorName };
636-
637625
if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true)
638626
{
639627
message.Refusal = refusal;
@@ -644,6 +632,36 @@ private sealed class OpenAIChatToolJson
644632
}
645633
}
646634

635+
/// <summary>Converts a list of <see cref="AIContent"/> to a list of <see cref="ChatMessageContentPart"/>.</summary>
636+
private static List<ChatMessageContentPart> GetContentParts(IList<AIContent> contents)
637+
{
638+
List<ChatMessageContentPart> parts = [];
639+
foreach (var content in contents)
640+
{
641+
switch (content)
642+
{
643+
case TextContent textContent:
644+
(parts ??= []).Add(ChatMessageContentPart.CreateTextPart(textContent.Text));
645+
break;
646+
647+
case ImageContent imageContent when imageContent.Data is { IsEmpty: false } data:
648+
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType));
649+
break;
650+
651+
case ImageContent imageContent when imageContent.Uri is string uri:
652+
(parts ??= []).Add(ChatMessageContentPart.CreateImagePart(new Uri(uri)));
653+
break;
654+
}
655+
}
656+
657+
if (parts.Count == 0)
658+
{
659+
parts.Add(ChatMessageContentPart.CreateTextPart(string.Empty));
660+
}
661+
662+
return parts;
663+
}
664+
647665
private static FunctionCallContent ParseCallContentFromJsonString(string json, string callId, string name) =>
648666
FunctionCallContent.CreateFromParsedArguments(json, callId, name,
649667
argumentParser: static json => JsonSerializer.Deserialize(json, JsonContext.Default.IDictionaryStringObject)!);

test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public void Constructor_String_PropsDefault(string? text)
1616
TextContent c = new(text);
1717
Assert.Null(c.RawRepresentation);
1818
Assert.Null(c.AdditionalProperties);
19-
Assert.Equal(text, c.Text);
19+
Assert.Equal(text ?? string.Empty, c.Text);
2020
}
2121

2222
[Fact]
@@ -34,13 +34,17 @@ public void Constructor_PropsRoundtrip()
3434
c.AdditionalProperties = props;
3535
Assert.Same(props, c.AdditionalProperties);
3636

37-
Assert.Null(c.Text);
37+
Assert.Equal(string.Empty, c.Text);
3838
c.Text = "text";
3939
Assert.Equal("text", c.Text);
4040
Assert.Equal("text", c.ToString());
4141

4242
c.Text = null;
43-
Assert.Null(c.Text);
43+
Assert.Equal(string.Empty, c.Text);
44+
Assert.Equal(string.Empty, c.ToString());
45+
46+
c.Text = string.Empty;
47+
Assert.Equal(string.Empty, c.Text);
4448
Assert.Equal(string.Empty, c.ToString());
4549
}
4650
}

test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,89 @@ public async Task MultipleMessages_NonStreaming()
321321
Assert.Equal(57, response.Usage.TotalTokenCount);
322322
}
323323

324+
[Fact]
325+
public async Task NullAssistantText_ContentSkipped_NonStreaming()
326+
{
327+
const string Input = """
328+
{
329+
"messages": [
330+
{
331+
"role": "assistant"
332+
},
333+
{
334+
"content": [
335+
{
336+
"text": "hello!",
337+
"type": "text"
338+
}
339+
],
340+
"role": "user"
341+
}
342+
],
343+
"model": "gpt-4o-mini"
344+
}
345+
""";
346+
347+
const string Output = """
348+
{
349+
"id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P",
350+
"object": "chat.completion",
351+
"created": 1727894187,
352+
"model": "gpt-4o-mini-2024-07-18",
353+
"choices": [
354+
{
355+
"index": 0,
356+
"message": {
357+
"role": "assistant",
358+
"content": "Hello.",
359+
"refusal": null
360+
},
361+
"logprobs": null,
362+
"finish_reason": "stop"
363+
}
364+
],
365+
"usage": {
366+
"prompt_tokens": 42,
367+
"completion_tokens": 15,
368+
"total_tokens": 57,
369+
"prompt_tokens_details": {
370+
"cached_tokens": 0
371+
},
372+
"completion_tokens_details": {
373+
"reasoning_tokens": 0
374+
}
375+
},
376+
"system_fingerprint": "fp_f85bea6784"
377+
}
378+
""";
379+
380+
using VerbatimHttpHandler handler = new(Input, Output);
381+
using HttpClient httpClient = new(handler);
382+
using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini");
383+
384+
List<ChatMessage> messages =
385+
[
386+
new(ChatRole.Assistant, (string?)null),
387+
new(ChatRole.User, "hello!"),
388+
];
389+
390+
var response = await client.CompleteAsync(messages);
391+
Assert.NotNull(response);
392+
393+
Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId);
394+
Assert.Equal("Hello.", response.Message.Text);
395+
Assert.Single(response.Message.Contents);
396+
Assert.Equal(ChatRole.Assistant, response.Message.Role);
397+
Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId);
398+
Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt);
399+
Assert.Equal(ChatFinishReason.Stop, response.FinishReason);
400+
401+
Assert.NotNull(response.Usage);
402+
Assert.Equal(42, response.Usage.InputTokenCount);
403+
Assert.Equal(15, response.Usage.OutputTokenCount);
404+
Assert.Equal(57, response.Usage.TotalTokenCount);
405+
}
406+
324407
[Fact]
325408
public async Task FunctionCallContent_NonStreaming()
326409
{

test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ private int CountTokens(ChatMessage message)
190190
int sum = 0;
191191
foreach (AIContent content in message.Contents)
192192
{
193-
if ((content as TextContent)?.Text is string text)
193+
if (content is TextContent text)
194194
{
195-
sum += _tokenizer.CountTokens(text);
195+
sum += _tokenizer.CountTokens(text.Text);
196196
}
197197
}
198198

0 commit comments

Comments
 (0)