Skip to content

GH-981: Use OpenAIAsyncClient for streaming in AzureOpenAiChatModel #1447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
Expand Down Expand Up @@ -108,34 +110,40 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
*/
private final OpenAIClient openAIClient;

/**
* The {@link OpenAIAsyncClient} used for streaming async operations.
*/
private final OpenAIAsyncClient openAIAsyncClient;

/**
* The configuration information for a chat completions request.
*/
private AzureOpenAiChatOptions defaultOptions;

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
public AzureOpenAiChatModel(OpenAIClientBuilder microsoftOpenAiClient) {
this(microsoftOpenAiClient,
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.build());
}

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
this(microsoftOpenAiClient, options, null);
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) {
this(openAIClientBuilder, options, null);
}

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
this(microsoftOpenAiClient, options, functionCallbackContext, List.of());
this(openAIClientBuilder, options, functionCallbackContext, List.of());
}

public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext, List<FunctionCallback> toolFunctionCallbacks) {
super(functionCallbackContext, options, toolFunctionCallbacks);
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(options, "AzureOpenAiChatOptions must not be null");
this.openAIClient = microsoftOpenAiClient;
this.openAIClient = openAIClientBuilder.buildClient();
this.openAIAsyncClient = openAIClientBuilder.buildAsyncClient();
this.defaultOptions = options;
}

Expand Down Expand Up @@ -170,11 +178,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
options.setStream(true);

IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
.getChatCompletionsStream(options.getModel(), options);

final var isFunctionCall = new AtomicBoolean(false);
final Flux<ChatCompletions> accessibleChatCompletionsFlux = Flux.fromIterable(chatCompletionsStream)
final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ public class AzureOpenAiImageModel implements ImageModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

@Autowired
private final OpenAIClient openAIClient;

private final AzureOpenAiImageOptions defaultOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
Expand Down Expand Up @@ -44,7 +45,7 @@ public class AzureChatCompletionsOptionsTests {
@Test
public void createRequestWithChatOptions() {

OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
OpenAIClientBuilder mockClient = Mockito.mock(OpenAIClientBuilder.class);

AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
.mock(AzureChatEnhancementConfiguration.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import static com.azure.core.http.policy.HttpLogDetailLevel.BODY_AND_HEADERS;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.Arrays;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.http.policy.HttpLogOptions;

/**
 * Integration test checking that the streaming and the imperative (blocking) invocations
 * of a {@link ChatClient} backed by {@link AzureOpenAiChatModel} yield the same relevant
 * content for one and the same prompt.
 * <p>
 * The test is enabled only when both the {@code AZURE_OPENAI_API_KEY} and the
 * {@code AZURE_OPENAI_ENDPOINT} environment variables are set to non-empty values.
 *
 * @author Soby Chacko
 */
@SpringBootTest(classes = AzureOpenAiChatClientTest.TestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
public class AzureOpenAiChatClientTest {

	/** First entry expected in the numbered list of states returned by the model. */
	private static final String FIRST_STATE_ENTRY = "1. Alabama - Montgomery";

	/** Last entry expected in the numbered list of states returned by the model. */
	private static final String LAST_STATE_ENTRY = "50. Wyoming - Cheyenne";

	@Autowired
	private ChatClient chatClient;

	/**
	 * Sends the same prompt once through the blocking API and once through the streaming
	 * API, then compares the normalized list of states extracted from both responses.
	 */
	@Test
	void streamingAndImperativeResponsesContainIdenticalRelevantResults() {
		String prompt = "Name all states in the USA and their capitals, add a space followed by a hyphen, then another space between the two. "
				+ "List them with a numerical index. Do not use any abbreviations in state or capitals.";

		// Imperative call
		String rawDataFromImperativeCall = chatClient.prompt(prompt).call().content();
		String imperativeStatesData = extractStatesData(rawDataFromImperativeCall);
		String formattedImperativeResponse = formatResponse(imperativeStatesData);

		// Streaming call: the chunks are concatenated in emission order.
		String stitchedResponseFromStream = chatClient.prompt(prompt)
			.stream()
			.content()
			.collectList()
			.block()
			.stream()
			.collect(Collectors.joining());
		String streamingStatesData = extractStatesData(stitchedResponseFromStream);
		String formattedStreamingResponse = formatResponse(streamingStatesData);

		// Assertions
		assertThat(formattedStreamingResponse).isEqualTo(formattedImperativeResponse);
		assertThat(formattedStreamingResponse).contains(FIRST_STATE_ENTRY);
		assertThat(formattedStreamingResponse).contains(LAST_STATE_ENTRY);
		assertThat(formattedStreamingResponse.lines().count()).isEqualTo(50);
	}

	/**
	 * Cuts the list of states out of a raw model response, dropping any text before the
	 * first entry and after the last entry.
	 * <p>
	 * Fails with a descriptive assertion error, instead of an opaque
	 * {@link StringIndexOutOfBoundsException} or a silently wrong substring, when either
	 * marker is missing from the response.
	 * @param rawData the complete text returned by the model
	 * @return the text from {@link #FIRST_STATE_ENTRY} through {@link #LAST_STATE_ENTRY},
	 * both inclusive
	 */
	private String extractStatesData(String rawData) {
		assertThat(rawData).as("raw model response").isNotNull().contains(FIRST_STATE_ENTRY);
		int firstStateIndex = rawData.indexOf(FIRST_STATE_ENTRY);

		// Search from the first entry onwards so that the end offset can never precede
		// the start offset.
		int lastStateStart = rawData.indexOf(LAST_STATE_ENTRY, firstStateIndex);
		assertThat(lastStateStart)
			.as("position of '%s' after '%s' in the model response", LAST_STATE_ENTRY, FIRST_STATE_ENTRY)
			.isNotNegative();

		return rawData.substring(firstStateIndex, lastStateStart + LAST_STATE_ENTRY.length());
	}

	/**
	 * Normalizes a response by stripping leading and trailing whitespace from every line,
	 * so that differences in padding or line terminators between the two call styles do
	 * not affect the comparison.
	 * @param response the text to normalize
	 * @return the stripped lines joined with {@code \n}
	 */
	private String formatResponse(String response) {
		return Arrays.stream(response.split("\n")).map(String::strip).collect(Collectors.joining("\n"));
	}

	/**
	 * Minimal Spring configuration wiring an {@link OpenAIClientBuilder}, the
	 * {@link AzureOpenAiChatModel} built from it and a {@link ChatClient} on top of that
	 * model.
	 */
	@SpringBootConfiguration
	public static class TestConfiguration {

		/**
		 * Creates the client builder from the {@code AZURE_OPENAI_API_KEY} and
		 * {@code AZURE_OPENAI_ENDPOINT} environment variables, with HTTP logging of
		 * bodies and headers switched on.
		 */
		@Bean
		public OpenAIClientBuilder openAIClient() {
			return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
				.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
				.serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
				.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
		}

		/**
		 * Creates the chat model under test, using the {@code gpt-4o} deployment and a
		 * limit of 1000 tokens.
		 */
		@Bean
		public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
			return new AzureOpenAiChatModel(openAIClientBuilder,
					AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build());

		}

		/** Creates the {@link ChatClient} that the test drives. */
		@Bean
		public ChatClient chatClient(AzureOpenAiChatModel azureOpenAiChatModel) {
			return ChatClient.builder(azureOpenAiChatModel).build();
		}

	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.OpenAIServiceVersion;
import com.azure.core.credential.AzureKeyCredential;
Expand Down Expand Up @@ -262,17 +261,16 @@ record ActorsFilmsRecord(String actor, List<String> movies) {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIClientBuilder openAIClientBuilder() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.serviceVersion(OpenAIServiceVersion.V2024_02_15_PREVIEW)
.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS))
.buildClient();
.httpLogOptions(new HttpLogOptions().setLogLevel(BODY_AND_HEADERS));
}

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) {
return new AzureOpenAiChatModel(openAIClient,
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
return new AzureOpenAiChatModel(openAIClientBuilder,
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-4o").withMaxTokens(1000).build());

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
Expand Down Expand Up @@ -40,8 +41,9 @@
* {@link Dispatcher} to integrate with Spring {@link MockMvc}.
*
* @author John Blum
* @author Soby Chacko
* @see org.springframework.boot.SpringBootConfiguration
* @see org.springframework.ai.test.config.MockAiTestConfiguration
* @see org.springframework.ai.azure.openai.MockAiTestConfiguration
* @since 0.7.0
*/
@SpringBootConfiguration
Expand All @@ -51,15 +53,13 @@
public class MockAzureOpenAiTestConfiguration {

@Bean
OpenAIClient microsoftAzureOpenAiClient(MockWebServer webServer) {

OpenAIClientBuilder microsoftAzureOpenAiClient(MockWebServer webServer) {
HttpUrl baseUrl = webServer.url(MockAiTestConfiguration.SPRING_AI_API_PATH);

return new OpenAIClientBuilder().endpoint(baseUrl.toString()).buildClient();
return new OpenAIClientBuilder().endpoint(baseUrl.toString());
}

@Bean
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) {
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder microsoftAzureOpenAiClient) {
return new AzureOpenAiChatModel(microsoftAzureOpenAiClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,13 @@ void functionCallSequentialAndStreamTest() {
public static class TestConfiguration {

@Bean
public OpenAIClient openAIClient() {
public OpenAIClientBuilder openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"));
}

@Bean
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) {
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClient, String selectedModel) {
return new AzureOpenAiChatModel(openAIClient,
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
}
Expand Down
Loading