Skip to content

Commit

Permalink
Migration of the AzureOpenAiChatModel to use the Azure OpenAI SDK (langchain4j#328)
Browse files Browse the repository at this point in the history

This PR fixes langchain4j#325

- [x] Migrate AzureOpenAiChatModel
- [x] Migrate AzureOpenAiEmbeddingModel
- [x] Migrate AzureOpenAiLanguageModel
- [x] Migrate AzureOpenAiStreamingChatModel
- [x] Migrate AzureOpenAiStreamingLanguageModel
- [x] Add a full suite of tests
  • Loading branch information
jdubois authored Dec 8, 2023
1 parent 0d2f743 commit 09ab6a1
Show file tree
Hide file tree
Showing 25 changed files with 2,934 additions and 422 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ build/
.vscode/

### Mac OS ###
.DS_Store
.DS_Store

### .env files contain local environment variables ###
.env
26 changes: 26 additions & 0 deletions langchain4j-azure-open-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-ai-openai</artifactId>
</dependency>

<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
<version>2.22.0</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-core</artifactId>
<version>2.22.0</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j2-impl</artifactId>
<version>2.20.0</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
package dev.langchain4j.model.azure;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.FunctionCallConfig;
import com.azure.core.http.ProxyOptions;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;

import java.net.Proxy;
import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.*;
import static java.time.Duration.ofSeconds;
import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.*;
import static java.util.Collections.singletonList;

/**
* Represents an OpenAI language model, hosted on Azure, that has a chat completion interface, such as gpt-3.5-turbo.
* <p>
* Mandatory parameters for initialization are: baseUrl, apiVersion and apiKey.
* Mandatory parameters for initialization are: endpoint, serviceVersion, apiKey and deploymentName.
* You can also provide your own OpenAIClient instance, if you need more flexibility.
* <p>
* There are two primary authentication methods to access Azure OpenAI:
* <p>
* 1. API Key Authentication: For this type of authentication, HTTP requests must include the
* API Key in the "api-key" HTTP header.
* API Key in the "api-key" HTTP header as follows: `api-key: OPENAI_API_KEY`
* <p>
* 2. Azure Active Directory Authentication: For this type of authentication, HTTP requests must include the
* authentication/access token in the "Authorization" HTTP header.
Expand All @@ -42,18 +43,32 @@
*/
public class AzureOpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

private final OpenAiClient client;
private OpenAIClient client;
private final String deploymentName;
private final Tokenizer tokenizer;
private final Double temperature;
private final Double topP;
private final Integer maxTokens;
private final Double presencePenalty;
private final Double frequencyPenalty;
private final Integer maxRetries;
private final Tokenizer tokenizer;

public AzureOpenAiChatModel(String baseUrl,
String apiVersion,
public AzureOpenAiChatModel(OpenAIClient client,
String deploymentName,
Tokenizer tokenizer,
Double temperature,
Double topP,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty) {

this(deploymentName, tokenizer, temperature, topP, maxTokens, presencePenalty, frequencyPenalty);
this.client = client;
}

public AzureOpenAiChatModel(String endpoint,
String serviceVersion,
String apiKey,
String deploymentName,
Tokenizer tokenizer,
Double temperature,
Double topP,
Expand All @@ -62,31 +77,30 @@ public AzureOpenAiChatModel(String baseUrl,
Double frequencyPenalty,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses) {

timeout = getOrDefault(timeout, ofSeconds(60));

this.client = OpenAiClient.builder()
.baseUrl(ensureNotBlank(baseUrl, "baseUrl"))
.azureApiKey(apiKey)
.apiVersion(apiVersion)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.build();
ProxyOptions proxyOptions,
boolean logRequestsAndResponses) {

this(deploymentName, tokenizer, temperature, topP, maxTokens, presencePenalty, frequencyPenalty);
this.client = setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses);
}



private AzureOpenAiChatModel(String deploymentName,
Tokenizer tokenizer,
Double temperature,
Double topP,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty) {

this.deploymentName = getOrDefault(deploymentName, "gpt-35-turbo-0613");
this.tokenizer = getOrDefault(tokenizer, new OpenAiTokenizer(GPT_3_5_TURBO));
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.maxTokens = maxTokens;
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = tokenizer;
}

@Override
Expand All @@ -108,29 +122,27 @@ private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.messages(toOpenAiMessages(messages))
.temperature(temperature)
.topP(topP)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty);
ChatCompletionsOptions options = new ChatCompletionsOptions(toOpenAiMessages(messages))
.setModel(deploymentName)
.setTemperature(temperature)
.setTopP(topP)
.setMaxTokens(maxTokens)
.setPresencePenalty(presencePenalty)
.setFrequencyPenalty(frequencyPenalty);

if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.functions(toFunctions(toolSpecifications));
options.setFunctions(toFunctions(toolSpecifications));
}
if (toolThatMustBeExecuted != null) {
requestBuilder.functionCall(toolThatMustBeExecuted.name());
options.setFunctionCall(new FunctionCallConfig(toolThatMustBeExecuted.name()));
}

ChatCompletionRequest request = requestBuilder.build();

ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
ChatCompletions chatCompletions = client.getChatCompletions(deploymentName, options);

return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.usage()),
finishReasonFrom(response.choices().get(0).finishReason())
aiMessageFrom(chatCompletions.getChoices().get(0).getMessage()),
tokenUsageFrom(chatCompletions.getUsage()),
finishReasonFrom(chatCompletions.getChoices().get(0).getFinishReason())
);
}

Expand All @@ -145,9 +157,10 @@ public static Builder builder() {

public static class Builder {

private String baseUrl;
private String apiVersion;
private String endpoint;
private String serviceVersion;
private String apiKey;
private String deploymentName;
private Tokenizer tokenizer;
private Double temperature;
private Double topP;
Expand All @@ -156,29 +169,29 @@ public static class Builder {
private Double frequencyPenalty;
private Duration timeout;
private Integer maxRetries;
private Proxy proxy;
private Boolean logRequests;
private Boolean logResponses;
private ProxyOptions proxyOptions;
private boolean logRequestsAndResponses;
private OpenAIClient openAIClient;

/**
* Sets the Azure OpenAI base URL. This is a mandatory parameter.
* Sets the Azure OpenAI endpoint. This is a mandatory parameter.
*
* @param baseUrl The Azure OpenAI base URL in the format: https://{resource}.openai.azure.com/openai/deployments/{deployment}
* @param endpoint The Azure OpenAI endpoint in the format: https://{resource}.openai.azure.com/
* @return builder
*/
public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

/**
* Sets the Azure OpenAI API version. This is a mandatory parameter.
* Sets the Azure OpenAI API service version. This is a mandatory parameter.
*
* @param apiVersion The Azure OpenAI api version in the format: 2023-05-15
* @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15
* @return builder
*/
public Builder apiVersion(String apiVersion) {
this.apiVersion = apiVersion;
public Builder serviceVersion(String serviceVersion) {
this.serviceVersion = serviceVersion;
return this;
}

Expand All @@ -193,6 +206,17 @@ public Builder apiKey(String apiKey) {
return this;
}

/**
* Sets the deployment name in Azure OpenAI. This is a mandatory parameter.
*
* @param deploymentName The Deployment name.
* @return builder
*/
public Builder deploymentName(String deploymentName) {
this.deploymentName = deploymentName;
return this;
}

public Builder tokenizer(Tokenizer tokenizer) {
this.tokenizer = tokenizer;
return this;
Expand Down Expand Up @@ -233,38 +257,57 @@ public Builder maxRetries(Integer maxRetries) {
return this;
}

public Builder proxy(Proxy proxy) {
this.proxy = proxy;
public Builder proxyOptions(ProxyOptions proxyOptions) {
this.proxyOptions = proxyOptions;
return this;
}

public Builder logRequests(Boolean logRequests) {
this.logRequests = logRequests;
public Builder logRequestsAndResponses(Boolean logRequestsAndResponses) {
this.logRequestsAndResponses = logRequestsAndResponses;
return this;
}

public Builder logResponses(Boolean logResponses) {
this.logResponses = logResponses;
/**
* Sets the Azure OpenAI client. This is an optional parameter, if you need more flexibility than using the endpoint, serviceVersion, apiKey, deploymentName parameters.
*
* @param openAIClient The Azure OpenAI client.
* @return builder
*/
public Builder openAIClient(OpenAIClient openAIClient) {
this.openAIClient = openAIClient;
return this;
}

public AzureOpenAiChatModel build() {
return new AzureOpenAiChatModel(
baseUrl,
apiVersion,
apiKey,
tokenizer,
temperature,
topP,
maxTokens,
presencePenalty,
frequencyPenalty,
timeout,
maxRetries,
proxy,
logRequests,
logResponses
);
if (openAIClient == null) {
return new AzureOpenAiChatModel(
endpoint,
serviceVersion,
apiKey,
deploymentName,
tokenizer,
temperature,
topP,
maxTokens,
presencePenalty,
frequencyPenalty,
timeout,
maxRetries,
proxyOptions,
logRequestsAndResponses
);
} else {
return new AzureOpenAiChatModel(
openAIClient,
deploymentName,
tokenizer,
temperature,
topP,
maxTokens,
presencePenalty,
frequencyPenalty
);
}
}
}
}
Loading

0 comments on commit 09ab6a1

Please sign in to comment.