
Commit 4326eb0

Address PR feedback
1 parent 6742995 commit 4326eb0

6 files changed: +124 -42 lines changed


src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantChatClient.cs

Lines changed: 16 additions & 10 deletions
@@ -42,21 +42,22 @@ internal sealed partial class OpenAIAssistantChatClient : IChatClient
     private readonly string _assistantId;
 
     /// <summary>The thread ID to use if none is supplied in <see cref="ChatOptions.ConversationId"/>.</summary>
-    private readonly string? _threadId;
+    private readonly string? _defaultThreadId;
 
     /// <summary>Initializes a new instance of the <see cref="OpenAIAssistantChatClient"/> class for the specified <see cref="AssistantClient"/>.</summary>
-    public OpenAIAssistantChatClient(AssistantClient client, string assistantId, string? threadId)
+    public OpenAIAssistantChatClient(AssistantClient assistantClient, string assistantId, string? defaultThreadId)
     {
-        _client = Throw.IfNull(client);
+        _client = Throw.IfNull(assistantClient);
         _assistantId = Throw.IfNullOrWhitespace(assistantId);
-        _threadId = threadId;
+
+        _defaultThreadId = defaultThreadId;
 
         // https://github.com/openai/openai-dotnet/issues/215
         // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages
         // implement the abstractions directly rather than providing adapters on top of the public APIs,
         // the package can provide such implementations separate from what's exposed in the public API.
         Uri providerUrl = typeof(AssistantClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
-            ?.GetValue(client) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint;
+            ?.GetValue(assistantClient) as Uri ?? OpenAIResponseChatClient.DefaultOpenAIEndpoint;
 
         _metadata = new("openai", providerUrl);
     }
@@ -85,7 +86,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
         (RunCreationOptions runOptions, List<FunctionResultContent>? toolResults) = CreateRunOptions(messages, options);
 
         // Get the thread ID.
-        string? threadId = options?.ConversationId ?? _threadId;
+        string? threadId = options?.ConversationId ?? _defaultThreadId;
         if (threadId is null && toolResults is not null)
         {
             Throw.ArgumentException(nameof(messages), "No thread ID was provided, but chat messages includes tool results.");
@@ -327,17 +328,22 @@ void IDisposable.Dispose()
             }
         }
 
-        // Process ChatMessages. System messages are turned into additional instructions.
-        // All other messages are added 1:1, treating assistant messages as agent messages
-        // and everything else as user messages.
+        // Process ChatMessages.
         StringBuilder? instructions = null;
         List<FunctionResultContent>? functionResults = null;
         foreach (var chatMessage in messages)
        {
             List<MessageContent> messageContents = [];
 
+            // Assistants doesn't support system/developer messages directly. It does support transient per-request instructions,
+            // so we can use the system/developer messages to build up a set of instructions that will be passed to the assistant
+            // as part of this request. However, in doing so, on a subsequent request that information will be lost, as there's no
+            // way to store per-thread instructions in the OpenAI Assistants API. We don't want to convert these to user messages,
+            // however, as that would then expose the system/developer messages in a way that might make the model more likely
+            // to include that information in its responses. System messages should ideally be instead done as instructions to
+            // the assistant when the assistant is created.
             if (chatMessage.Role == ChatRole.System ||
-                chatMessage.Role == OpenAIChatClient.ChatRoleDeveloper)
+                chatMessage.Role == OpenAIResponseChatClient.ChatRoleDeveloper)
             {
                 instructions ??= new();
                 foreach (var textContent in chatMessage.Contents.OfType<TextContent>())
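
For review context, here is a minimal, hedged sketch of the caller-visible effect of the comment added above: system/developer messages are folded into transient per-request instructions rather than stored on the thread. The API key, assistant ID, and prompts are placeholders, and GetResponseAsync is the standard Microsoft.Extensions.AI entry point, not something introduced by this commit.

    #pragma warning disable OPENAI001 // AssistantClient is experimental
    using System;
    using Microsoft.Extensions.AI;
    using OpenAI.Assistants;

    IChatClient chatClient = new AssistantClient("OPENAI_API_KEY")   // placeholder key
        .AsIChatClient("asst_abc123");                               // placeholder assistant ID

    // The system message is converted to per-request instructions, so it is not persisted
    // on the thread and must be resent on each call (or baked into the assistant itself).
    ChatResponse response = await chatClient.GetResponseAsync(
    [
        new ChatMessage(ChatRole.System, "Answer in one short sentence."),
        new ChatMessage(ChatRole.User, "What does an Assistants thread store?"),
    ]);
    Console.WriteLine(response.Text);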

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 3 additions & 5 deletions
@@ -110,8 +110,6 @@ void IDisposable.Dispose()
         // Nothing to dispose. Implementation required for the IChatClient interface.
     }
 
-    internal static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");
-
     /// <summary>Converts an Extensions chat message enumerable to an OpenAI chat message enumerable.</summary>
     private static IEnumerable<OpenAI.Chat.ChatMessage> ToOpenAIChatMessages(IEnumerable<ChatMessage> inputs, JsonSerializerOptions options)
     {
@@ -122,12 +120,12 @@ void IDisposable.Dispose()
         {
             if (input.Role == ChatRole.System ||
                 input.Role == ChatRole.User ||
-                input.Role == ChatRoleDeveloper)
+                input.Role == OpenAIResponseChatClient.ChatRoleDeveloper)
             {
                 var parts = ToOpenAIChatContent(input.Contents);
                 yield return
                     input.Role == ChatRole.System ? new SystemChatMessage(parts) { ParticipantName = input.AuthorName } :
-                    input.Role == ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
+                    input.Role == OpenAIResponseChatClient.ChatRoleDeveloper ? new DeveloperChatMessage(parts) { ParticipantName = input.AuthorName } :
                     new UserChatMessage(parts) { ParticipantName = input.AuthorName };
             }
             else if (input.Role == ChatRole.Tool)
@@ -619,7 +617,7 @@ private static ChatRole FromOpenAIChatRole(ChatMessageRole role) =>
             ChatMessageRole.User => ChatRole.User,
             ChatMessageRole.Assistant => ChatRole.Assistant,
             ChatMessageRole.Tool => ChatRole.Tool,
-            ChatMessageRole.Developer => ChatRoleDeveloper,
+            ChatMessageRole.Developer => OpenAIResponseChatClient.ChatRoleDeveloper,
             _ => new ChatRole(role.ToString()),
         };

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@ public static IChatClient AsIChatClient(this OpenAIResponseClient responseClient
         new OpenAIResponseChatClient(responseClient);
 
     /// <summary>Gets an <see cref="IChatClient"/> for use with this <see cref="AssistantClient"/>.</summary>
-    /// <param name="client">The <see cref="AssistantClient"/> instance to be accessed as an <see cref="IChatClient"/>.</param>
+    /// <param name="assistantClient">The <see cref="AssistantClient"/> instance to be accessed as an <see cref="IChatClient"/>.</param>
     /// <param name="assistantId">The unique identifier of the assistant with which to interact.</param>
     /// <param name="threadId">
     /// An optional existing thread identifier for the chat session. This serves as a default, and may be overridden per call to
@@ -36,8 +36,8 @@ public static IChatClient AsIChatClient(this OpenAIResponseClient responseClient
     /// </param>
     /// <returns>An <see cref="IChatClient"/> instance configured to interact with the specified agent and thread.</returns>
     [Experimental("OPENAI001")]
-    public static IChatClient AsIChatClient(this AssistantClient client, string assistantId, string? threadId = null) =>
-        new OpenAIAssistantChatClient(client, assistantId, threadId);
+    public static IChatClient AsIChatClient(this AssistantClient assistantClient, string assistantId, string? threadId = null) =>
+        new OpenAIAssistantChatClient(assistantClient, assistantId, threadId);
 
     /// <summary>Gets an <see cref="ISpeechToTextClient"/> for use with this <see cref="AudioClient"/>.</summary>
     /// <param name="audioClient">The client.</param>
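
As a usage note on the renamed parameter, a short hedged sketch (IDs below are placeholders): the thread ID given to AsIChatClient is only a default and can still be overridden per call via ChatOptions.ConversationId.

    #pragma warning disable OPENAI001 // AssistantClient is experimental
    using Microsoft.Extensions.AI;
    using OpenAI.Assistants;

    IChatClient chatClient = new AssistantClient("OPENAI_API_KEY")     // placeholder key
        .AsIChatClient("asst_abc123", threadId: "thread_default");     // placeholder IDs

    // Uses the default thread supplied above.
    await chatClient.GetResponseAsync("Hello!");

    // Overrides the default thread for this call only.
    await chatClient.GetResponseAsync(
        [new ChatMessage(ChatRole.User, "Hello from another thread!")],
        new ChatOptions { ConversationId = "thread_other" });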

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs

Lines changed: 16 additions & 11 deletions
@@ -29,6 +29,9 @@ internal sealed partial class OpenAIResponseChatClient : IChatClient
     /// <summary>Gets the default OpenAI endpoint.</summary>
     internal static Uri DefaultOpenAIEndpoint { get; } = new("https://api.openai.com/v1");
 
+    /// <summary>Gets a <see cref="ChatRole"/> for "developer".</summary>
+    internal static ChatRole ChatRoleDeveloper { get; } = new ChatRole("developer");
+
     /// <summary>Metadata about the client.</summary>
     private readonly ChatClientMetadata _metadata;
 
@@ -85,7 +88,7 @@ public async Task<ChatResponse> GetResponseAsync(
         // Convert and return the results.
         ChatResponse response = new()
         {
-            ConversationId = openAIResponse.Id,
+            ConversationId = openAIOptions.StoredOutputEnabled is false ? null : openAIResponse.Id,
             CreatedAt = openAIResponse.CreatedAt,
             FinishReason = ToFinishReason(openAIResponse.IncompleteStatusDetails?.Reason),
             Messages = [new(ChatRole.Assistant, [])],
@@ -164,6 +167,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
         // Make the call to the OpenAIResponseClient and process the streaming results.
         DateTimeOffset? createdAt = null;
         string? responseId = null;
+        string? conversationId = null;
         string? modelId = null;
         string? lastMessageId = null;
         ChatRole? lastRole = null;
@@ -176,18 +180,19 @@
                 case StreamingResponseCreatedUpdate createdUpdate:
                     createdAt = createdUpdate.Response.CreatedAt;
                     responseId = createdUpdate.Response.Id;
+                    conversationId = openAIOptions.StoredOutputEnabled is false ? null : responseId;
                     modelId = createdUpdate.Response.Model;
                     goto default;
 
                 case StreamingResponseCompletedUpdate completedUpdate:
                     yield return new()
                     {
+                        Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
+                        ConversationId = conversationId,
+                        CreatedAt = createdAt,
                         FinishReason =
                             ToFinishReason(completedUpdate.Response?.IncompleteStatusDetails?.Reason) ??
                             (functionCallInfos is not null ? ChatFinishReason.ToolCalls : ChatFinishReason.Stop),
-                        Contents = ToUsageDetails(completedUpdate.Response) is { } usage ? [new UsageContent(usage)] : [],
-                        ConversationId = responseId,
-                        CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
                         RawRepresentation = streamingUpdate,
@@ -220,7 +225,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
                     lastRole = ToChatRole(messageItem?.Role);
                     yield return new ChatResponseUpdate(lastRole, outputTextDeltaUpdate.Delta)
                     {
-                        ConversationId = responseId,
+                        ConversationId = conversationId,
                         CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
@@ -255,7 +260,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
                     lastRole = ChatRole.Assistant;
                     yield return new ChatResponseUpdate(lastRole, [fci])
                     {
-                        ConversationId = responseId,
+                        ConversationId = conversationId,
                         CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
@@ -272,7 +277,6 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
                 case StreamingResponseErrorUpdate errorUpdate:
                     yield return new ChatResponseUpdate
                     {
-                        ConversationId = responseId,
                         Contents =
                         [
                             new ErrorContent(errorUpdate.Message)
@@ -281,6 +285,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
                                 Details = errorUpdate.Param,
                             }
                         ],
+                        ConversationId = conversationId,
                         CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
@@ -293,21 +298,21 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
                 case StreamingResponseRefusalDoneUpdate refusalDone:
                     yield return new ChatResponseUpdate
                     {
+                        Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
+                        ConversationId = conversationId,
                         CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
                         RawRepresentation = streamingUpdate,
                         ResponseId = responseId,
                         Role = lastRole,
-                        ConversationId = responseId,
-                        Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
                     };
                     break;
 
                 default:
                     yield return new ChatResponseUpdate
                     {
-                        ConversationId = responseId,
+                        ConversationId = conversationId,
                         CreatedAt = createdAt,
                         MessageId = lastMessageId,
                         ModelId = modelId,
@@ -331,7 +336,7 @@ private static ChatRole ToChatRole(MessageRole? role) =>
         role switch
         {
             MessageRole.System => ChatRole.System,
-            MessageRole.Developer => OpenAIChatClient.ChatRoleDeveloper,
+            MessageRole.Developer => ChatRoleDeveloper,
             MessageRole.User => ChatRole.User,
             _ => ChatRole.Assistant,
         };
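
The ConversationId changes in this file reduce to a single rule; below is a minimal hedged sketch of that decision in isolation (not the implementation itself, just an illustration using the same ResponseCreationOptions type the client already builds per request):

    #pragma warning disable OPENAI001 // Responses API types are experimental
    using System;
    using OpenAI.Responses;

    // Surface the response ID as a conversation ID only when the service is storing the output;
    // with StoredOutputEnabled explicitly false there is no server-side state to continue from.
    static string? GetConversationId(ResponseCreationOptions openAIOptions, string responseId) =>
        openAIOptions.StoredOutputEnabled is false ? null : responseId;

    Console.WriteLine(GetConversationId(new ResponseCreationOptions { StoredOutputEnabled = false }, "resp_123") ?? "(none)");
    Console.WriteLine(GetConversationId(new ResponseCreationOptions(), "resp_123")); // "resp_123"; null/unset is treated as stored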

test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIAssistantChatClientIntegrationTests.cs

Lines changed: 9 additions & 13 deletions
@@ -9,6 +9,7 @@
 
 using System;
 using System.Linq;
+using System.Net;
 using System.Net.Http;
 using System.Text.RegularExpressions;
 using System.Threading.Tasks;
@@ -45,26 +46,21 @@ public class OpenAIAssistantChatClientIntegrationTests : ChatClientIntegrationTe
     public override Task MultiModal_DescribeImage() => Task.CompletedTask;
     public override Task MultiModal_DescribePdf() => Task.CompletedTask;
 
-    // [Fact]
+    // [Fact] // uncomment and run to clear out _all_ threads in your OpenAI account
     public async Task DeleteAllThreads()
     {
         using HttpClient client = new(new HttpClientHandler
         {
-            AutomaticDecompression = System.Net.DecompressionMethods.GZip,
+            AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
         });
 
-        client.DefaultRequestHeaders.Add("accept", "*/*");
-        client.DefaultRequestHeaders.Add("accept-encoding", "gzip");
-        client.DefaultRequestHeaders.Add("accept-language", "en-US,en;q=0.9");
-        client.DefaultRequestHeaders.Add("openai-beta", "assistants=v2");
-        client.DefaultRequestHeaders.Add("origin", "https://platform.openai.com");
-        client.DefaultRequestHeaders.Add("user-agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0");
-
         // These values need to be filled in. The bearer token needs to be sniffed from a browser
-        // session interacting with the dashboard (e.g. F12 networking tools).
-        client.DefaultRequestHeaders.Add("authorization", $"Bearer TODO-SESSION-TOKEN");
-        client.DefaultRequestHeaders.Add("openai-organization", "TODO");
-        client.DefaultRequestHeaders.Add("openai-project", "TODO");
+        // session interacting with the dashboard (e.g. use F12 networking tools to look at request headers
+        // made to "https://api.openai.com/v1/threads?limit=10" after clicking on Assistants | Threads in the
+        // OpenAI portal dashboard).
+        client.DefaultRequestHeaders.Add("authorization", $"Bearer sess-ENTERYOURSESSIONTOKEN");
+        client.DefaultRequestHeaders.Add("openai-organization", "org-ENTERYOURORGID");
+        client.DefaultRequestHeaders.Add("openai-project", "proj_ENTERYOURPROJECTID");
 
         AssistantClient ac = new AssistantClient(Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey")!);
         while (true)
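
Since the loop body is truncated in this view, here is a purely hypothetical sketch (not part of this commit) of the kind of cleanup such a loop could perform: list threads via the dashboard-only endpoint mentioned in the comment above, then delete each one through the official AssistantClient API. The helper itself and the JSON property names ("data", "id") are assumptions.

    #pragma warning disable OPENAI001 // AssistantClient is experimental
    using System.Linq;
    using System.Net.Http;
    using System.Text.Json;
    using System.Threading.Tasks;
    using OpenAI.Assistants;

    static async Task DeleteListedThreadsAsync(HttpClient client, AssistantClient ac)
    {
        while (true)
        {
            // "client" carries the dashboard session headers set up in the test above.
            string json = await client.GetStringAsync("https://api.openai.com/v1/threads?limit=10");
            string[] ids = JsonDocument.Parse(json).RootElement
                .GetProperty("data").EnumerateArray()
                .Select(t => t.GetProperty("id").GetString()!)
                .ToArray();

            if (ids.Length == 0)
            {
                break;
            }

            foreach (string id in ids)
            {
                await ac.DeleteThreadAsync(id);
            }
        }
    }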
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.ClientModel;
+using Azure.AI.OpenAI;
+using Microsoft.Extensions.Caching.Distributed;
+using Microsoft.Extensions.Caching.Memory;
+using OpenAI;
+using OpenAI.Assistants;
+using Xunit;
+
+#pragma warning disable S103 // Lines should not be too long
+#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
+
+namespace Microsoft.Extensions.AI;
+
+public class OpenAIAssistantChatClientTests
+{
+    [Fact]
+    public void AsIChatClient_InvalidArgs_Throws()
+    {
+        Assert.Throws<ArgumentNullException>("assistantClient", () => ((AssistantClient)null!).AsIChatClient("assistantId"));
+        Assert.Throws<ArgumentNullException>("assistantId", () => new AssistantClient("ignored").AsIChatClient(null!));
+    }
+
+    [Theory]
+    [InlineData(false)]
+    [InlineData(true)]
+    public void AsIChatClient_OpenAIClient_ProducesExpectedMetadata(bool useAzureOpenAI)
+    {
+        Uri endpoint = new("http://localhost/some/endpoint");
+
+        var client = useAzureOpenAI ?
+            new AzureOpenAIClient(endpoint, new ApiKeyCredential("key")) :
+            new OpenAIClient(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint });
+
+        IChatClient[] clients =
+        [
+            client.GetAssistantClient().AsIChatClient("assistantId"),
+            client.GetAssistantClient().AsIChatClient("assistantId", "threadId"),
+        ];
+
+        foreach (var chatClient in clients)
+        {
+            var metadata = chatClient.GetService<ChatClientMetadata>();
+            Assert.Equal("openai", metadata?.ProviderName);
+            Assert.Equal(endpoint, metadata?.ProviderUri);
+        }
+    }
+
+    [Fact]
+    public void GetService_AssistantClient_SuccessfullyReturnsUnderlyingClient()
+    {
+        AssistantClient assistantClient = new OpenAIClient("key").GetAssistantClient();
+        IChatClient chatClient = assistantClient.AsIChatClient("assistantId");
+
+        Assert.Same(assistantClient, chatClient.GetService<AssistantClient>());
+
+        Assert.Null(chatClient.GetService<OpenAIClient>());
+
+        using IChatClient pipeline = chatClient
+            .AsBuilder()
+            .UseFunctionInvocation()
+            .UseOpenTelemetry()
+            .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions())))
+            .Build();
+
+        Assert.NotNull(pipeline.GetService<FunctionInvokingChatClient>());
+        Assert.NotNull(pipeline.GetService<DistributedCachingChatClient>());
+        Assert.NotNull(pipeline.GetService<CachingChatClient>());
+        Assert.NotNull(pipeline.GetService<OpenTelemetryChatClient>());
+
+        Assert.Same(assistantClient, pipeline.GetService<AssistantClient>());
+        Assert.IsType<FunctionInvokingChatClient>(pipeline.GetService<IChatClient>());
+    }
+}
