Skip to content

Commit

Permalink
.Net: Azure OpenAI Temporary Fix [401 Retry Bug] (#9465)
Browse files Browse the repository at this point in the history
### Motivation and Context

- Fix #8929
  • Loading branch information
RogerBarreto authored Oct 30, 2024
1 parent 5ccfeaf commit 846c565
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.ClientModel.Primitives;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using SemanticKernel.Connectors.AzureOpenAI.Core;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Core;

public sealed class ClientCoreTests : IDisposable
{
private readonly MultipleHttpMessageHandlerStub _multiHttpMessageHandlerStub;
private readonly HttpClient _httpClient;

public ClientCoreTests()
{
this._multiHttpMessageHandlerStub = new MultipleHttpMessageHandlerStub();
this._httpClient = new HttpClient(this._multiHttpMessageHandlerStub);
}

public void Dispose()
{
this._httpClient.Dispose();
this._multiHttpMessageHandlerStub.Dispose();
}

[Fact]
public async Task AuthenticationHeaderShouldBeProvidedOnlyOnce()
{
// Arrange
using var firstResponse = new HttpResponseMessage(System.Net.HttpStatusCode.TooManyRequests);
using var secondResponse = new HttpResponseMessage(System.Net.HttpStatusCode.TooManyRequests);
using var thirdResponse = new HttpResponseMessage(System.Net.HttpStatusCode.TooManyRequests);

this._multiHttpMessageHandlerStub.ResponsesToReturn.AddRange([firstResponse, secondResponse, thirdResponse]);
var options = new AzureOpenAIClientOptions()
{
Transport = new HttpClientPipelineTransport(this._httpClient),
RetryPolicy = new ClientRetryPolicy(2),
NetworkTimeout = TimeSpan.FromSeconds(10),
};

// Bug fix workaround
options.AddPolicy(new SingleAuthorizationHeaderPolicy(), PipelinePosition.PerTry);

var azureClient = new AzureOpenAIClient(
endpoint: new Uri("http://any"),
credential: new TestJWTBearerTokenCredential(),
options: options);

var clientCore = new AzureClientCore("deployment-name", azureClient);

ChatHistory chatHistory = [];
chatHistory.AddUserMessage("User test");

// Act
var exception = await Record.ExceptionAsync(() => clientCore.GetChatMessageContentsAsync("model-id", chatHistory, null, null, CancellationToken.None));

// Assert
Assert.NotNull(exception);
Assert.Equal(3, this._multiHttpMessageHandlerStub.RequestHeaders.Count);

foreach (var requestHeaders in this._multiHttpMessageHandlerStub.RequestHeaders)
{
this._multiHttpMessageHandlerStub.RequestHeaders[2]!.TryGetValues("Authorization", out var authHeaders);
Assert.NotNull(authHeaders);
Assert.Single(authHeaders);
}
}

private sealed class TestJWTBearerTokenCredential : TokenCredential
{
public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
return new AccessToken("JWT", DateTimeOffset.Now.AddHours(1));
}

public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
return ValueTask.FromResult(new AccessToken("JWT", DateTimeOffset.Now.AddHours(1)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Http;
using OpenAI;
using SemanticKernel.Connectors.AzureOpenAI.Core;

namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

Expand Down Expand Up @@ -146,6 +147,7 @@ internal static AzureOpenAIClientOptions GetAzureOpenAIClientOptions(HttpClient?

options.UserAgentApplicationId = HttpHeaderConstant.Values.UserAgent;
options.AddPolicy(CreateRequestHeaderPolicy(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(AzureClientCore))), PipelinePosition.PerCall);
options.AddPolicy(new SingleAuthorizationHeaderPolicy(), PipelinePosition.PerTry);

if (httpClient is not null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft. All rights reserved.

using System.ClientModel.Primitives;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace SemanticKernel.Connectors.AzureOpenAI.Core;

/// <summary>
/// This class is used to remove duplicate Authorization headers from the request Azure OpenAI Bug.
/// https://github.com/Azure/azure-sdk-for-net/issues/46109 (Remove when beta.2 is released)
/// </summary>
internal sealed class SingleAuthorizationHeaderPolicy : PipelinePolicy
{
private const string AuthorizationHeaderName = "Authorization";

public override void Process(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
RemoveDuplicateHeaderValues(message.Request.Headers);

ProcessNext(message, pipeline, currentIndex);
}

public override async ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList<PipelinePolicy> pipeline, int currentIndex)
{
RemoveDuplicateHeaderValues(message.Request.Headers);

await ProcessNextAsync(message, pipeline, currentIndex).ConfigureAwait(false);
}

private static void RemoveDuplicateHeaderValues(PipelineRequestHeaders headers)
{
if (headers.TryGetValues(AuthorizationHeaderName, out var headerValues)
&& headerValues is not null
#if NET
&& headerValues.TryGetNonEnumeratedCount(out var count) && count > 1
#else
&& headerValues.Count() > 1
#endif
)
{
headers.Set(AuthorizationHeaderName, headerValues.First());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

<ItemGroup>
<InternalsVisibleTo Include="SemanticKernel.Connectors.OpenAI.UnitTests" />
<InternalsVisibleTo Include="SemanticKernel.Connectors.AzureOpenAI.UnitTests" />
<InternalsVisibleTo Include="Microsoft.SemanticKernel.Connectors.AzureOpenAI" />
</ItemGroup>

Expand Down

0 comments on commit 846c565

Please sign in to comment.