Skip to content

Commit

Permalink
Support for GitHub Models using the Azure AI Inference API (langchain…
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Sep 24, 2024
1 parent d546c64 commit 3579664
Show file tree
Hide file tree
Showing 15 changed files with 52 additions and 46 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GCP_VERTEXAI_ENDPOINT: ${{ secrets.GCP_VERTEXAI_ENDPOINT }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GOOGLE_AI_GEMINI_API_KEY: ${{ secrets.GOOGLE_AI_GEMINI_API_KEY }}
HF_API_KEY: ${{ secrets.HF_API_KEY }}
JINA_API_KEY: ${{ secrets.JINA_API_KEY }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/nightly.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GCP_VERTEXAI_ENDPOINT: ${{ secrets.GCP_VERTEXAI_ENDPOINT }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GOOGLE_AI_GEMINI_API_KEY: ${{ secrets.GOOGLE_AI_GEMINI_API_KEY }}
HF_API_KEY: ${{ secrets.HF_API_KEY }}
JINA_API_KEY: ${{ secrets.JINA_API_KEY }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ jobs:
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GCP_VERTEXAI_ENDPOINT: ${{ secrets.GCP_VERTEXAI_ENDPOINT }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GOOGLE_AI_GEMINI_API_KEY: ${{ secrets.GOOGLE_AI_GEMINI_API_KEY }}
HF_API_KEY: ${{ secrets.HF_API_KEY }}
JINA_API_KEY: ${{ secrets.JINA_API_KEY }}
Expand Down
6 changes: 6 additions & 0 deletions langchain4j-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-github-models</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-hugging-face</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ public static class Builder {
private List<String> stop;
private Double presencePenalty;
private Double frequencyPenalty;
Long seed;
ChatCompletionsResponseFormat responseFormat;
private Long seed;
private ChatCompletionsResponseFormat responseFormat;
private Duration timeout;
private Integer maxRetries;
private ProxyOptions proxyOptions;
Expand Down Expand Up @@ -296,6 +296,11 @@ public Builder modelName(String modelName) {
return this;
}

public Builder modelName(GitHubModelsChatModelName modelName) {
this.modelName = modelName.toString();
return this;
}

public Builder maxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,13 @@ public enum GitHubModelsChatModelName {
MISTRAL_SMALL("Mistral-small", "mistral");

private final String modelName;

private final String modelType;

GitHubModelsChatModelName(String modelName, String modelType) {
this.modelName = modelName;
this.modelType = modelType;
}

public String modelName() {
return modelName;
}

public String modelType() {
return modelType;
}

@Override
public String toString() {
return modelName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ public Builder modelName(String modelName) {
return this;
}

public Builder modelName(GitHubModelsEmbeddingModelName modelName) {
this.modelName = modelName.toString();
return this;
}

public Builder timeout(Duration timeout) {
this.timeout = timeout;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,20 @@ public enum GitHubModelsEmbeddingModelName {
this.dimension = dimension;
}

public String modelName() {
return modelName;
}

@Override
public String toString() {
return modelName;
}

public Integer dimension() {
return dimension;
}

private static final Map<String, Integer> KNOWN_DIMENSION = new HashMap<>(GitHubModelsEmbeddingModelName.values().length);

static {
for (GitHubModelsEmbeddingModelName embeddingModelName : GitHubModelsEmbeddingModelName.values()) {
KNOWN_DIMENSION.put(embeddingModelName.toString(), embeddingModelName.dimension());
KNOWN_DIMENSION.put(embeddingModelName.toString(), embeddingModelName.dimension);
}
}

public static Integer knownDimension(String modelName) {
static Integer knownDimension(String modelName) {
return KNOWN_DIMENSION.get(modelName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ public static class Builder {
private Double presencePenalty;
private Double frequencyPenalty;
private Duration timeout;
Long seed;
ChatCompletionsResponseFormat responseFormat;
private Long seed;
private ChatCompletionsResponseFormat responseFormat;
private Integer maxRetries;
private ProxyOptions proxyOptions;
private boolean logRequestsAndResponses;
Expand Down Expand Up @@ -346,6 +346,11 @@ public Builder modelName(String modelName) {
return this;
}

public Builder modelName(GitHubModelsChatModelName modelName) {
this.modelName = modelName.toString();
return this;
}

public Builder maxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@
*/
class GitHubModelsStreamingResponseBuilder {

Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingResponseBuilder.class);

private final StringBuffer contentBuilder = new StringBuffer();
private final StringBuffer toolNameBuilder = new StringBuffer();
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
private int inputTokenCount = 0;
private int outputTokenCount = 0;
private String toolExecutionsIndex = "call_undefined";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
Expand All @@ -30,7 +31,8 @@
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;

public class GitHubModelsChatModelIT {
@EnabledIfEnvironmentVariable(named = "GITHUB_TOKEN", matches = ".+")
class GitHubModelsChatModelIT {

private static final Logger logger = LoggerFactory.getLogger(GitHubModelsChatModelIT.class);

Expand All @@ -39,7 +41,7 @@ void should_generate_answer_and_finish_reason_stop() {

GitHubModelsChatModel model = GitHubModelsChatModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
.modelName(GPT_4_O_MINI.modelName())
.modelName(GPT_4_O_MINI)
.logRequestsAndResponses(true)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.ChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import static java.util.Collections.singletonList;

public class GitHubModelsChatModelListenerIT extends ChatModelListenerIT {
@EnabledIfEnvironmentVariable(named = "GITHUB_TOKEN", matches = ".+")
class GitHubModelsChatModelListenerIT extends ChatModelListenerIT {

@Override
protected ChatLanguageModel createModel(ChatModelListener listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

Expand All @@ -15,11 +16,12 @@
import static dev.langchain4j.model.github.GitHubModelsEmbeddingModelName.TEXT_EMBEDDING_3_SMALL;
import static org.assertj.core.api.Assertions.assertThat;

@EnabledIfEnvironmentVariable(named = "GITHUB_TOKEN", matches = ".+")
class GitHubModelsEmbeddingModelIT {

EmbeddingModel model = GitHubModelsEmbeddingModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
.modelName(TEXT_EMBEDDING_3_SMALL.modelName())
.modelName(TEXT_EMBEDDING_3_SMALL)
.logRequestsAndResponses(true)
.build();

Expand Down Expand Up @@ -69,8 +71,7 @@ void should_support_all_string_model_names(GitHubModelsEmbeddingModelName modelN
// given
EmbeddingModel model = GitHubModelsEmbeddingModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
.modelName(modelName.modelName())
.dimensions(modelName.dimension())
.modelName(modelName.toString())
.logRequestsAndResponses(true)
.build();

Expand All @@ -89,7 +90,7 @@ void should_embed_text_with_embedding_shortening() {

EmbeddingModel model = GitHubModelsEmbeddingModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
.modelName(TEXT_EMBEDDING_3_SMALL.modelName())
.modelName(TEXT_EMBEDDING_3_SMALL)
.dimensions(dimensions)
.logRequestsAndResponses(true)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -26,31 +26,27 @@
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.github.GitHubModelsChatModelName.PHI_3_5_MINI_INSTRUCT;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;

@EnabledIfEnvironmentVariable(named = "GITHUB_TOKEN", matches = ".+")
class GitHubModelsStreamingChatModelIT {

private static final Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingChatModelIT.class);

public long STREAMING_TIMEOUT = 120;

@ParameterizedTest(name = "Model name {0} with async client set to {1}")
@CsvSource({
"Phi-3.5-mini-instruct, true",
"Phi-3.5-mini-instruct, false"
})
void should_stream_answer(String modelName, boolean useAsyncClient) throws Exception {
@Test
void should_stream_answer() throws Exception {

CompletableFuture<String> futureAnswer = new CompletableFuture<>();
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
.modelName(modelName)
.modelName(PHI_3_5_MINI_INSTRUCT)
.logRequestsAndResponses(true)
.build();

Expand Down Expand Up @@ -95,7 +91,7 @@ public void onError(Throwable error) {
"Mistral-nemo",
"meta-llama-3-8b-instruct"
})
void test_different_available_models(String modelName) throws Exception {
void test_different_available_models(String modelName) {

StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder()
.gitHubToken(System.getenv("GITHUB_TOKEN"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import dev.langchain4j.model.chat.StreamingChatModelListenerIT;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import static java.util.Collections.singletonList;

@EnabledIfEnvironmentVariable(named = "GITHUB_TOKEN", matches = ".+")
class GitHubModelsStreamingChatModelListenerIT extends StreamingChatModelListenerIT {

@Override
Expand Down

0 comments on commit 3579664

Please sign in to comment.