From d546c64511adad6830cb739f8e987cd7aaec40e8 Mon Sep 17 00:00:00 2001 From: Julien Dubois Date: Tue, 24 Sep 2024 19:18:42 +0200 Subject: [PATCH] Support for GitHub Models using the Azure AI Inference API (#1807) Fix #1719 This adds GitHub Models (see https://github.com/marketplace/models ) support with the new Azure AI Inference API Java SDK. --- .../language-models/github-models.md | 151 ++++++ .../language-models/google-ai-gemini.md | 2 +- .../language-models/google-palm.md | 2 +- .../google-vertex-ai-gemini.md | 2 +- .../language-models/hugging-face.md | 2 +- .../integrations/language-models/index.md | 1 + .../integrations/language-models/jlama.md | 2 +- .../integrations/language-models/local-ai.md | 2 +- .../language-models/mistral-ai.md | 2 +- .../integrations/language-models/ollama.md | 2 +- .../integrations/language-models/open-ai.md | 2 +- .../integrations/language-models/qianfan.md | 2 +- .../language-models/workers-ai.md | 2 +- .../integrations/language-models/zhipu-ai.md | 2 +- .../azure/AzureOpenAiStreamingChatModel.java | 8 +- langchain4j-github-models/pom.xml | 235 +++++++++ .../model/github/GitHubModelsChatModel.java | 426 ++++++++++++++++ .../github/GitHubModelsChatModelName.java | 58 +++ .../github/GitHubModelsEmbeddingModel.java | 258 ++++++++++ .../GitHubModelsEmbeddingModelName.java | 46 ++ .../GitHubModelsStreamingChatModel.java | 469 ++++++++++++++++++ .../GitHubModelsStreamingResponseBuilder.java | 138 ++++++ .../github/InternalGitHubModelHelper.java | 366 ++++++++++++++ .../GitHubModelsChatModelBuilderFactory.java | 11 + ...HubModelsEmbeddingModelBuilderFactory.java | 11 + ...odelsStreamingChatModelBuilderFactory.java | 11 + .../model/github/GitHubModelsChatModelIT.java | 377 ++++++++++++++ .../GitHubModelsChatModelListenerIT.java | 45 ++ .../github/GitHubModelsEmbeddingModelIT.java | 108 ++++ .../GitHubModelsStreamingChatModelIT.java | 367 ++++++++++++++ ...HubModelsStreamingChatModelListenerIT.java | 56 +++ 
.../src/test/resources/log4j2.xml | 18 + pom.xml | 1 + 33 files changed, 3170 insertions(+), 15 deletions(-) create mode 100644 docs/docs/integrations/language-models/github-models.md create mode 100644 langchain4j-github-models/pom.xml create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModel.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModelName.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModel.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelName.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModel.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsStreamingResponseBuilder.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/InternalGitHubModelHelper.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsChatModelBuilderFactory.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsEmbeddingModelBuilderFactory.java create mode 100644 langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsStreamingChatModelBuilderFactory.java create mode 100644 langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelIT.java create mode 100644 langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelListenerIT.java create mode 100644 langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelIT.java create mode 100644 langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelIT.java create 
mode 100644 langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelListenerIT.java create mode 100644 langchain4j-github-models/src/test/resources/log4j2.xml diff --git a/docs/docs/integrations/language-models/github-models.md b/docs/docs/integrations/language-models/github-models.md new file mode 100644 index 00000000000..cabf501d87c --- /dev/null +++ b/docs/docs/integrations/language-models/github-models.md @@ -0,0 +1,151 @@ +--- +sidebar_position: 6 +--- + +# GitHub Models + +If you want to develop a generative AI application, you can use GitHub Models to find and experiment with AI models for free. +Once you are ready to bring your application to production, you can switch to a token from a paid Azure account. + +## GitHub Models Documentation + +- [GitHub Models Documentation](https://docs.github.com/en/github-models) +- [GitHub Models Marketplace](https://github.com/marketplace/models) + +## Maven Dependency + +### Plain Java + +```xml + + dev.langchain4j + langchain4j-github-models + 0.34.0 + +``` + +## GitHub token + +To use GitHub Models, you need to use a GitHub token for authentication. + +Token are created and managed in [GitHub Developer Settings > Personal access tokens](https://github.com/settings/tokens). + +Once you have a token, you can set it as an environment variable and use it in your code: + +```bash +export GITHUB_TOKEN="" +``` + +## Creating a `GitHubModelsChatModel` with a GitHub token + +### Plain Java + +```java +GitHubModelsChatModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName("gpt-4o-mini") + .build(); +``` + +This will create an instance of `GitHubModelsChatModel`. +Model parameters (e.g. `temperature`) can be customized by providing values in the `GitHubModelsChatModel`'s builder. 
+ +### Spring Boot + +Create a `GitHubModelsChatModelConfiguration` Spring Bean: + +```Java +package com.example.demo.configuration.github; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.github.GitHubModelsChatModel; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; + +@Configuration +@Profile("github") +public class GitHubModelsChatModelConfiguration { + + @Value("${GITHUB_TOKEN}") + private String gitHubToken; + + @Bean + ChatLanguageModel gitHubModelsChatLanguageModel() { + return GitHubModelsChatModel.builder() + .gitHubToken(gitHubToken) + .modelName("gpt-4o-mini") + .logRequestsAndResponses(true) + .build(); + } +} +``` + +This configuration will create an `GitHubModelsChatModel` bean, +which can be either used by an [AI Service](https://docs.langchain4j.dev/tutorials/spring-boot-integration/#langchain4j-spring-boot-starter) +or autowired where needed, for example: + +```java +@RestController +class ChatLanguageModelController { + + ChatLanguageModel chatLanguageModel; + + ChatLanguageModelController(ChatLanguageModel chatLanguageModel) { + this.chatLanguageModel = chatLanguageModel; + } + + @GetMapping("/model") + public String model(@RequestParam(value = "message", defaultValue = "Hello") String message) { + return chatLanguageModel.generate(message); + } +} +``` + +## Creating a `GitHubModelsStreamingChatModel` with a GitHub token + +### Plain Java + +```java +GitHubModelsStreamingChatModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName("gpt-4o-mini") + .logRequestsAndResponses(true) + .build(); +``` + +### Spring Boot + +Create a `GitHubModelsStreamingChatModelConfiguration` Spring Bean: +```Java +package com.example.demo.configuration.github; + +import 
dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.github.GitHubModelsChatModel; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Profile; + +@Configuration +@Profile("github") +public class GitHubModelsStreamingChatModelConfiguration { + + @Value("${GITHUB_TOKEN}") + private String gitHubToken; + + @Bean + GitHubModelsStreamingChatModel gitHubModelsStreamingChatLanguageModel() { + return GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName("gpt-4o-mini") + .logRequestsAndResponses(true) + .build(); + } +} +``` + +## Examples + +- [GitHub Models Examples](https://github.com/langchain4j/langchain4j-examples/tree/main/github-models-examples/src/main/java) diff --git a/docs/docs/integrations/language-models/google-ai-gemini.md b/docs/docs/integrations/language-models/google-ai-gemini.md index 8da8e8cae8a..f800a5f732b 100644 --- a/docs/docs/integrations/language-models/google-ai-gemini.md +++ b/docs/docs/integrations/language-models/google-ai-gemini.md @@ -1,5 +1,5 @@ --- -sidebar_position: 6 +sidebar_position: 7 --- # Google AI Gemini diff --git a/docs/docs/integrations/language-models/google-palm.md b/docs/docs/integrations/language-models/google-palm.md index d647ba2eaac..a216f7d6583 100644 --- a/docs/docs/integrations/language-models/google-palm.md +++ b/docs/docs/integrations/language-models/google-palm.md @@ -1,5 +1,5 @@ --- -sidebar_position: 8 +sidebar_position: 9 --- # Google Vertex AI PaLM 2 diff --git a/docs/docs/integrations/language-models/google-vertex-ai-gemini.md b/docs/docs/integrations/language-models/google-vertex-ai-gemini.md index 993285e8ada..53c13604c98 100644 --- a/docs/docs/integrations/language-models/google-vertex-ai-gemini.md +++ b/docs/docs/integrations/language-models/google-vertex-ai-gemini.md @@ 
-1,5 +1,5 @@ --- -sidebar_position: 7 +sidebar_position: 8 --- # Google Vertex AI Gemini diff --git a/docs/docs/integrations/language-models/hugging-face.md b/docs/docs/integrations/language-models/hugging-face.md index cdd6e4e1c54..480220fce55 100644 --- a/docs/docs/integrations/language-models/hugging-face.md +++ b/docs/docs/integrations/language-models/hugging-face.md @@ -1,5 +1,5 @@ --- -sidebar_position: 9 +sidebar_position: 10 --- # Hugging Face diff --git a/docs/docs/integrations/language-models/index.md b/docs/docs/integrations/language-models/index.md index a60d8fd9ba2..cdf19dcd19b 100644 --- a/docs/docs/integrations/language-models/index.md +++ b/docs/docs/integrations/language-models/index.md @@ -11,6 +11,7 @@ sidebar_position: 0 | [Azure OpenAI](/integrations/language-models/azure-open-ai) | ✅ | ✅ | ✅ | text, image | ✅ | | | | | [ChatGLM](/integrations/language-models/chatglm) | | | | text | | | | | | [DashScope](/integrations/language-models/dashscope) | ✅ | ✅ | | text, image, audio | ✅ | | | | +| [GitHub Models](/integrations/language-models/github-models) | ✅ | ✅ | ✅ | text | ✅ | | | | | [Google AI Gemini](/integrations/language-models/google-ai-gemini) | | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | | | [Google Vertex AI Gemini](/integrations/language-models/google-vertex-ai-gemini) | ✅ | ✅ | ✅ | text, image, audio, video, PDF | ✅ | | | | | [Google Vertex AI PaLM 2](/integrations/language-models/google-palm) | | | | text | | | ✅ | | diff --git a/docs/docs/integrations/language-models/jlama.md b/docs/docs/integrations/language-models/jlama.md index 6592a855519..be7a68ed244 100644 --- a/docs/docs/integrations/language-models/jlama.md +++ b/docs/docs/integrations/language-models/jlama.md @@ -1,5 +1,5 @@ --- -sidebar_position: 10 +sidebar_position: 11 --- # Jlama diff --git a/docs/docs/integrations/language-models/local-ai.md b/docs/docs/integrations/language-models/local-ai.md index cec26841d42..67d6acc1ec9 100644 --- 
a/docs/docs/integrations/language-models/local-ai.md +++ b/docs/docs/integrations/language-models/local-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 11 +sidebar_position: 12 --- # LocalAI diff --git a/docs/docs/integrations/language-models/mistral-ai.md b/docs/docs/integrations/language-models/mistral-ai.md index f5bd7bf8fe6..3d24d8645df 100644 --- a/docs/docs/integrations/language-models/mistral-ai.md +++ b/docs/docs/integrations/language-models/mistral-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 12 +sidebar_position: 13 --- # MistralAI diff --git a/docs/docs/integrations/language-models/ollama.md b/docs/docs/integrations/language-models/ollama.md index c0711f0492d..98d1bb32e90 100644 --- a/docs/docs/integrations/language-models/ollama.md +++ b/docs/docs/integrations/language-models/ollama.md @@ -1,5 +1,5 @@ --- -sidebar_position: 13 +sidebar_position: 14 --- # Ollama diff --git a/docs/docs/integrations/language-models/open-ai.md b/docs/docs/integrations/language-models/open-ai.md index 3f374b41c7f..18a89666f7d 100644 --- a/docs/docs/integrations/language-models/open-ai.md +++ b/docs/docs/integrations/language-models/open-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 14 +sidebar_position: 15 --- # OpenAI diff --git a/docs/docs/integrations/language-models/qianfan.md b/docs/docs/integrations/language-models/qianfan.md index 59f7165cfbf..42df3d4dc58 100644 --- a/docs/docs/integrations/language-models/qianfan.md +++ b/docs/docs/integrations/language-models/qianfan.md @@ -1,5 +1,5 @@ --- -sidebar_position: 15 +sidebar_position: 16 --- # Qianfan diff --git a/docs/docs/integrations/language-models/workers-ai.md b/docs/docs/integrations/language-models/workers-ai.md index 7214e68dccf..4c0a548c417 100644 --- a/docs/docs/integrations/language-models/workers-ai.md +++ b/docs/docs/integrations/language-models/workers-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 16 +sidebar_position: 17 --- # Cloudflare Workers AI diff --git 
a/docs/docs/integrations/language-models/zhipu-ai.md b/docs/docs/integrations/language-models/zhipu-ai.md index c542f5f3e4a..e86dbf65f54 100644 --- a/docs/docs/integrations/language-models/zhipu-ai.md +++ b/docs/docs/integrations/language-models/zhipu-ai.md @@ -1,5 +1,5 @@ --- -sidebar_position: 17 +sidebar_position: 18 --- # Zhipu AI diff --git a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java index f88a9531e40..5e6bf80505b 100644 --- a/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java +++ b/langchain4j-azure-open-ai/src/main/java/dev/langchain4j/model/azure/AzureOpenAiStreamingChatModel.java @@ -148,10 +148,12 @@ public AzureOpenAiStreamingChatModel(String endpoint, Map customHeaders) { this(deploymentName, tokenizer, maxTokens, temperature, topP, logitBias, user, n, stop, presencePenalty, frequencyPenalty, dataSources, enhancements, seed, responseFormat, listeners); - if(useAsyncClient) + if(useAsyncClient) { this.asyncClient = setupAsyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders); - else - this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders); } + } else { + this.client = setupSyncClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders); + } + } public AzureOpenAiStreamingChatModel(String endpoint, String serviceVersion, diff --git a/langchain4j-github-models/pom.xml b/langchain4j-github-models/pom.xml new file mode 100644 index 00000000000..dcec78a6d3c --- /dev/null +++ b/langchain4j-github-models/pom.xml @@ -0,0 +1,235 @@ + + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 
0.35.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-github-models + LangChain4j :: Integration :: GitHub Models + + + + + dev.langchain4j + langchain4j-core + + + + com.azure + azure-ai-inference + 1.0.0-beta.1 + + + com.azure + azure-core + + + io.netty + netty-transport-native-epoll + + + io.netty + netty-transport-native-unix-common + + + io.netty + netty-codec-http2 + + + io.netty + netty-transport + + + io.netty + netty-common + + + io.netty + netty-resolver + + + io.netty + netty-handler-proxy + + + io.netty + netty-codec-http + + + io.netty + netty-buffer + + + io.netty + netty-codec + + + io.netty + netty-handler + + + org.junit.jupiter + junit-jupiter-api + + + + + + com.azure + azure-core + 1.52.0 + + + + io.netty + netty-transport-native-epoll + 4.1.110.Final + + + + io.netty + netty-codec-http2 + 4.1.110.Final + + + + io.netty + netty-transport + 4.1.110.Final + + + io.netty + netty-common + 4.1.110.Final + + + io.netty + netty-resolver + 4.1.110.Final + + + io.netty + netty-handler-proxy + 4.1.110.Final + + + io.netty + netty-codec-http + 4.1.110.Final + + + io.netty + netty-buffer + 4.1.110.Final + + + io.netty + netty-codec + 4.1.110.Final + + + io.netty + netty-handler + 4.1.110.Final + + + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + + org.apache.logging.log4j + log4j-api + 2.22.0 + test + + + + org.apache.logging.log4j + log4j-core + 2.22.0 + test + + + + org.apache.logging.log4j + log4j-slf4j2-impl + 2.20.0 + test + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + org.mockito + mockito-core + test + + + + org.mockito + mockito-junit-jupiter + test + + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.5.0 + + + enforce + + + + + + + enforce + + + + + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + \ No newline at end 
of file diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModel.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModel.java new file mode 100644 index 00000000000..a6c4eea3dd6 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModel.java @@ -0,0 +1,426 @@ +package dev.langchain4j.model.github; + +import com.azure.ai.inference.ChatCompletionsClient; +import com.azure.ai.inference.ModelServiceVersion; +import com.azure.ai.inference.models.ChatCompletions; +import com.azure.ai.inference.models.ChatCompletionsOptions; +import com.azure.ai.inference.models.ChatCompletionsResponseFormat; +import com.azure.core.exception.HttpResponseException; +import com.azure.core.http.ProxyOptions; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.listener.*; +import dev.langchain4j.model.github.spi.GitHubModelsChatModelBuilderFactory; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static dev.langchain4j.data.message.AiMessage.aiMessage; +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.github.InternalGitHubModelHelper.*; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + +/** + * Represents a 
language model, hosted on GitHub Models, that has a chat completion interface, such as gpt-4o. + *

+ * Mandatory parameters for initialization are: gitHubToken (the GitHub Token used for authentication) and modelName (the name of the model to use). + * You can also provide your own ChatCompletionsClient instance, if you need more flexibility. + *

+ * The list of models, as well as the documentation and a playground to test them, can be found at https://github.com/marketplace/models + */ +public class GitHubModelsChatModel implements ChatLanguageModel { + + private static final Logger logger = LoggerFactory.getLogger(GitHubModelsChatModel.class); + + private ChatCompletionsClient client; + private final String modelName; + private final Integer maxTokens; + private final Double temperature; + private final Double topP; + private final List stop; + private final Double presencePenalty; + private final Double frequencyPenalty; + private final Long seed; + private final ChatCompletionsResponseFormat responseFormat; + private final List listeners; + + private GitHubModelsChatModel(ChatCompletionsClient client, + String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + List listeners) { + + this(modelName, maxTokens, temperature, topP, stop, presencePenalty, frequencyPenalty, seed, responseFormat, listeners); + this.client = client; + } + + private GitHubModelsChatModel(String endpoint, + ModelServiceVersion serviceVersion, + String gitHubToken, + String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + Duration timeout, + Integer maxRetries, + ProxyOptions proxyOptions, + boolean logRequestsAndResponses, + List listeners, + String userAgentSuffix, + Map customHeaders) { + + this(modelName, maxTokens, temperature, topP, stop, presencePenalty, frequencyPenalty, seed, responseFormat, listeners); + this.client = setupChatCompletionsBuilder(endpoint, serviceVersion, gitHubToken, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders).buildClient(); + } + + private 
GitHubModelsChatModel(String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + List listeners) { + + this.modelName = ensureNotBlank(modelName, "modelName"); + this.maxTokens = maxTokens; + this.temperature = temperature; + this.topP = topP; + this.stop = copyIfNotNull(stop); + this.presencePenalty = presencePenalty; + this.frequencyPenalty = frequencyPenalty; + this.seed = seed; + this.responseFormat = responseFormat; + this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); + } + + @Override + public Response generate(List messages) { + return generate(messages, null, null); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + return generate(messages, toolSpecifications, null); + } + + @Override + public Response generate(List messages, ToolSpecification toolSpecification) { + return generate(messages, singletonList(toolSpecification), toolSpecification); + } + + private Response generate(List messages, + List toolSpecifications, + ToolSpecification toolThatMustBeExecuted + ) { + ChatCompletionsOptions options = new ChatCompletionsOptions(toAzureAiMessages(messages)) + .setModel(modelName) + .setMaxTokens(maxTokens) + .setTemperature(temperature) + .setTopP(topP) + .setStop(stop) + .setPresencePenalty(presencePenalty) + .setFrequencyPenalty(frequencyPenalty) + .setSeed(seed) + .setResponseFormat(responseFormat); + + if (toolThatMustBeExecuted != null) { + options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted))); + options.setToolChoice(toToolChoice(toolThatMustBeExecuted)); + } + if (!isNullOrEmpty(toolSpecifications)) { + options.setTools(toToolDefinitions(toolSpecifications)); + } + + ChatModelRequest modelListenerRequest = createModelListenerRequest(options, messages, toolSpecifications); + Map attributes = new ConcurrentHashMap<>(); + 
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); + listeners.forEach(listener -> { + try { + listener.onRequest(requestContext); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + }); + + try { + ChatCompletions chatCompletions = client.complete(options); + Response response = Response.from( + aiMessageFrom(chatCompletions.getChoices().get(0).getMessage()), + tokenUsageFrom(chatCompletions.getUsage()), + finishReasonFrom(chatCompletions.getChoices().get(0).getFinishReason()) + ); + + ChatModelResponse modelListenerResponse = createModelListenerResponse( + chatCompletions.getId(), + options.getModel(), + response + ); + ChatModelResponseContext responseContext = new ChatModelResponseContext( + modelListenerResponse, + modelListenerRequest, + attributes + ); + listeners.forEach(listener -> { + try { + listener.onResponse(responseContext); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + }); + + return response; + } catch (HttpResponseException httpResponseException) { + logger.info("Error generating response, {}", httpResponseException.getValue()); + FinishReason exceptionFinishReason = contentFilterManagement(httpResponseException, "content_filter"); + Response response = Response.from( + aiMessage(httpResponseException.getMessage()), + null, + exceptionFinishReason + ); + ChatModelErrorContext errorContext = new ChatModelErrorContext( + httpResponseException, + modelListenerRequest, + null, + attributes + ); + + listeners.forEach(listener -> { + try { + listener.onError(errorContext); + } catch (Exception e2) { + logger.warn("Exception while calling model listener", e2); + } + }); + return response; + } + } + + public static Builder builder() { + for (GitHubModelsChatModelBuilderFactory factory : loadFactories(GitHubModelsChatModelBuilderFactory.class)) { + return factory.get(); + } + return new Builder(); + } + + public static 
class Builder { + + private String endpoint; + private ModelServiceVersion serviceVersion; + private String gitHubToken; + private String modelName; + private Integer maxTokens; + private Double temperature; + private Double topP; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + Long seed; + ChatCompletionsResponseFormat responseFormat; + private Duration timeout; + private Integer maxRetries; + private ProxyOptions proxyOptions; + private boolean logRequestsAndResponses; + private ChatCompletionsClient chatCompletionsClient; + private String userAgentSuffix; + private List listeners; + private Map customHeaders; + + /** + * Sets the GitHub Models endpoint. The default endpoint will be used if this isn't set. + * + * @param endpoint The GitHub Models endpoint in the format: https://models.inference.ai.azure.com + * @return builder + */ + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + /** + * Sets the Azure OpenAI API service version. If left blank, the latest service version will be used. + * + * @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15 + * @return builder + */ + public Builder serviceVersion(ModelServiceVersion serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + + /** + * Sets the GitHub token to access GitHub Models. + * + * @param gitHubToken The GitHub token. + * @return builder + */ + public Builder gitHubToken(String gitHubToken) { + this.gitHubToken = gitHubToken; + return this; + } + + /** + * Sets the model name in Azure AI Inference API. This is a mandatory parameter. + * + * @param modelName The Model name. 
+ * @return builder + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder stop(List stop) { + this.stop = stop; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder seed(Long seed) { + this.seed = seed; + return this; + } + + public Builder responseFormat(ChatCompletionsResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxyOptions(ProxyOptions proxyOptions) { + this.proxyOptions = proxyOptions; + return this; + } + + public Builder logRequestsAndResponses(Boolean logRequestsAndResponses) { + this.logRequestsAndResponses = logRequestsAndResponses; + return this; + } + + public Builder userAgentSuffix(String userAgentSuffix) { + this.userAgentSuffix = userAgentSuffix; + return this; + } + + /** + * Sets the Azure AI Inference API client. This is an optional parameter, if you need more flexibility than the common parameters. + * + * @param chatCompletionsClient The Azure AI Inference API client. 
+ * @return builder + */ + public Builder chatCompletionsClient(ChatCompletionsClient chatCompletionsClient) { + this.chatCompletionsClient = chatCompletionsClient; + return this; + } + + public Builder listeners(List listeners) { + this.listeners = listeners; + return this; + } + + public Builder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + public GitHubModelsChatModel build() { + if (chatCompletionsClient == null) { + return new GitHubModelsChatModel( + endpoint, + serviceVersion, + gitHubToken, + modelName, + maxTokens, + temperature, + topP, + stop, + presencePenalty, + frequencyPenalty, + seed, + responseFormat, + timeout, + maxRetries, + proxyOptions, + logRequestsAndResponses, + listeners, + userAgentSuffix, + customHeaders + ); + + } else { + return new GitHubModelsChatModel( + chatCompletionsClient, + modelName, + maxTokens, + temperature, + topP, + stop, + presencePenalty, + frequencyPenalty, + seed, + responseFormat, + listeners + ); + } + } + } +} \ No newline at end of file diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModelName.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModelName.java new file mode 100644 index 00000000000..944f0d73204 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsChatModelName.java @@ -0,0 +1,58 @@ +package dev.langchain4j.model.github; + +public enum GitHubModelsChatModelName { + + GPT_4_O("gpt-4o", "gpt-4"), + GPT_4_O_MINI("gpt-4o-mini", "gpt-4"), + //O1_MINI("o1-mini", "o1"), + //O1_PREVIEW("o1-preview", "o1"), + + PHI_3_5_MINI_INSTRUCT("Phi-3.5-mini-instruct", "phi"), + PHI_3_5_VISION_INSTRUCT("Phi-3.5-vision-instruct", "phi"), + PHI_3_MEDIUM_INSTRUCT_128K("Phi-3-medium-128k-instruct", "phi"), + PHI_3_MEDIUM_INSTRUCT_4K("Phi-3-medium-4k-instruct", "phi"), + PHI_3_MINI_INSTRUCT_128K("Phi-3-mini-128k-instruct", "phi"), + 
/**
 * Chat-completion models hosted on GitHub Models.
 * Each constant carries the exact model identifier expected by the Azure AI Inference API
 * ({@link #modelName()}) and the model family it belongs to ({@link #modelType()}).
 */
public enum GitHubModelsChatModelName {

    GPT_4_O("gpt-4o", "gpt-4"),
    GPT_4_O_MINI("gpt-4o-mini", "gpt-4"),
    // Not enabled yet — kept for reference until the o1 family is supported:
    //O1_MINI("o1-mini", "o1"),
    //O1_PREVIEW("o1-preview", "o1"),

    PHI_3_5_MINI_INSTRUCT("Phi-3.5-mini-instruct", "phi"),
    PHI_3_5_VISION_INSTRUCT("Phi-3.5-vision-instruct", "phi"),
    PHI_3_MEDIUM_INSTRUCT_128K("Phi-3-medium-128k-instruct", "phi"),
    PHI_3_MEDIUM_INSTRUCT_4K("Phi-3-medium-4k-instruct", "phi"),
    PHI_3_MINI_INSTRUCT_128K("Phi-3-mini-128k-instruct", "phi"),
    PHI_3_MINI_INSTRUCT_4K("Phi-3-mini-4k-instruct", "phi"),
    PHI_3_SMALL_INSTRUCT_128K("Phi-3-small-128k-instruct", "phi"),
    PHI_3_SMALL_INSTRUCT_8K("Phi-3-small-8k-instruct", "phi"),

    AI21_JAMBA_1_5_LARGE("ai21-jamba-1.5-large", "ai21"),
    AI21_JAMBA_1_5_MINI("ai21-jamba-1.5-mini", "ai21"),
    AI21_JAMBA_INSTRUCT("ai21-jamba-instruct", "ai21"),

    COHERE_COMMAND_R("cohere-command-r", "cohere"),
    COHERE_COMMAND_R_PLUS("cohere-command-r-plus", "cohere"),

    META_LLAMA_3_1_405B_INSTRUCT("meta-llama-3.1-405b-instruct", "meta-llama"),
    META_LLAMA_3_1_70B_INSTRUCT("meta-llama-3.1-70b-instruct", "meta-llama"),
    META_LLAMA_3_1_8B_INSTRUCT("meta-llama-3.1-8b-instruct", "meta-llama"),
    META_LLAMA_3_70B_INSTRUCT("meta-llama-3-70b-instruct", "meta-llama"),
    META_LLAMA_3_8B_INSTRUCT("meta-llama-3-8b-instruct", "meta-llama"),

    MISTRAL_NEMO("Mistral-nemo", "mistral"),
    MISTRAL_LARGE("Mistral-large", "mistral"),
    MISTRAL_LARGE_2407("Mistral-large-2407", "mistral"),
    MISTRAL_SMALL("Mistral-small", "mistral");

    // Identifier sent to the API; also what toString() yields.
    private final String modelName;
    // Model family (e.g. "phi", "mistral"); useful for grouping/filtering.
    private final String modelType;

    GitHubModelsChatModelName(String modelName, String modelType) {
        this.modelName = modelName;
        this.modelType = modelType;
    }

    /** Returns the exact model identifier used by the API. */
    public String modelName() {
        return modelName;
    }

    /** Returns the model family this constant belongs to. */
    public String modelType() {
        return modelType;
    }

    @Override
    public String toString() {
        return modelName;
    }
}
com.azure.ai.inference.models.EmbeddingItem; +import com.azure.ai.inference.models.EmbeddingsResult; +import com.azure.core.http.ProxyOptions; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel; +import dev.langchain4j.model.github.spi.GitHubModelsEmbeddingModelBuilderFactory; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.data.embedding.Embedding.from; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.github.InternalGitHubModelHelper.setupEmbeddingsBuilder; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.stream.Collectors.toList; + +/** + * Represents an embedding model, hosted on GitHub Models, such as text-embedding-3-small. + *

+ * Mandatory parameters for initialization are: gitHubToken (the GitHub Token used for authentication) and modelName (the name of the model to use). + * You can also provide your own EmbeddingsClient instance, if you need more flexibility. + *

+ * The list of models, as well as the documentation and a playground to test them, can be found at https://github.com/marketplace/models + */ +public class GitHubModelsEmbeddingModel extends DimensionAwareEmbeddingModel { + + private static final Logger logger = LoggerFactory.getLogger(GitHubModelsEmbeddingModel.class); + + private static final int BATCH_SIZE = 16; + + private EmbeddingsClient client; + private final String modelName; + private final Integer dimensions; + + private GitHubModelsEmbeddingModel(EmbeddingsClient client, + String modelName, + Integer dimensions) { + this(modelName, dimensions); + this.client = client; + } + + private GitHubModelsEmbeddingModel(String endpoint, + ModelServiceVersion serviceVersion, + String gitHubToken, + String modelName, + Duration timeout, + Integer maxRetries, + ProxyOptions proxyOptions, + boolean logRequestsAndResponses, + String userAgentSuffix, + Integer dimensions, + Map customHeaders) { + + this(modelName, dimensions); + this.client = setupEmbeddingsBuilder(endpoint, serviceVersion, gitHubToken, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders) + .buildClient(); + } + + private GitHubModelsEmbeddingModel(String modelName, Integer dimensions) { + this.modelName = ensureNotBlank(modelName, "modelName"); + this.dimensions = dimensions; + } + + /** + * Embeds the provided text segments, processing a maximum of 16 segments at a time. + * For more information, refer to the documentation here. + * + * @param textSegments A list of text segments. + * @return A list of corresponding embeddings. 
+ */ + @Override + public Response> embedAll(List textSegments) { + + List texts = textSegments.stream() + .map(TextSegment::text) + .collect(toList()); + + return embedTexts(texts); + } + + private Response> embedTexts(List texts) { + + List embeddings = new ArrayList<>(); + + int inputTokenCount = 0; + for (int i = 0; i < texts.size(); i += BATCH_SIZE) { + + List batch = texts.subList(i, Math.min(i + BATCH_SIZE, texts.size())); + + EmbeddingsResult result = client.embed(batch, dimensions, null, null, modelName, null); + for (EmbeddingItem embeddingItem : result.getData()) { + Embedding embedding = from(embeddingItem.getEmbeddingList()); + embeddings.add(embedding); + } + inputTokenCount += result.getUsage().getPromptTokens(); + } + + return Response.from( + embeddings, + new TokenUsage(inputTokenCount) + ); + } + + @Override + protected Integer knownDimension() { + if (dimensions != null) { + return dimensions; + } + return GitHubModelsEmbeddingModelName.knownDimension(modelName); + } + + public static Builder builder() { + for (GitHubModelsEmbeddingModelBuilderFactory factory : loadFactories(GitHubModelsEmbeddingModelBuilderFactory.class)) { + return factory.get(); + } + return new Builder(); + } + + public static class Builder { + + private String endpoint; + private ModelServiceVersion serviceVersion; + private String gitHubToken; + private String modelName; + private Duration timeout; + private Integer maxRetries; + private ProxyOptions proxyOptions; + private boolean logRequestsAndResponses; + private EmbeddingsClient embeddingsClient; + private String userAgentSuffix; + private Integer dimensions; + private Map customHeaders; + + /** + * Sets the GitHub Models endpoint. The default endpoint will be used if this isn't set. 
+ * + * @param endpoint The GitHub Models endpoint in the format: https://models.inference.ai.azure.com + * @return builder + */ + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + /** + * Sets the Azure OpenAI API service version. If left blank, the latest service version will be used. + * + * @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15 + * @return builder + */ + public Builder serviceVersion(ModelServiceVersion serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + + /** + * Sets the GitHub token to access GitHub Models. + * + * @param gitHubToken The GitHub token. + * @return builder + */ + public Builder gitHubToken(String gitHubToken) { + this.gitHubToken = gitHubToken; + return this; + } + + /** + * Sets the model name in Azure AI Inference API. This is a mandatory parameter. + * + * @param modelName The Model name. + * @return builder + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxyOptions(ProxyOptions proxyOptions) { + this.proxyOptions = proxyOptions; + return this; + } + + public Builder logRequestsAndResponses(boolean logRequestsAndResponses) { + this.logRequestsAndResponses = logRequestsAndResponses; + return this; + } + + /** + * Sets the Azure AI Inference API client. This is an optional parameter, if you need more flexibility than the common parameters. + * + * @param embeddingsClient The Azure AI Inference API client. 
+ * @return builder + */ + public Builder embeddingsClient(EmbeddingsClient embeddingsClient) { + this.embeddingsClient = embeddingsClient; + return this; + } + + public Builder userAgentSuffix(String userAgentSuffix) { + this.userAgentSuffix = userAgentSuffix; + return this; + } + + public Builder dimensions(Integer dimensions){ + this.dimensions = dimensions; + return this; + } + + public Builder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + public GitHubModelsEmbeddingModel build() { + if (embeddingsClient == null) { + return new GitHubModelsEmbeddingModel( + endpoint, + serviceVersion, + gitHubToken, + modelName, + timeout, + maxRetries, + proxyOptions, + logRequestsAndResponses, + userAgentSuffix, + dimensions, + customHeaders); + } else { + return new GitHubModelsEmbeddingModel( + embeddingsClient, + modelName, + dimensions + ); + } + } + } +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelName.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelName.java new file mode 100644 index 00000000000..9ee3729fda9 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelName.java @@ -0,0 +1,46 @@ +package dev.langchain4j.model.github; + +import java.util.HashMap; +import java.util.Map; + +public enum GitHubModelsEmbeddingModelName { + + TEXT_EMBEDDING_3_SMALL("text-embedding-3-small", 1536), + TEXT_EMBEDDING_3_LARGE("text-embedding-3-large", 3072), + + COHERE_EMBED_V3_ENGLISH("cohere-embed-v3-english", 1024), + COHERE_EMBED_V3_MULTILINGUAL("cohere-embed-v3-multilingual", 1024); + + private final String modelName; + private final Integer dimension; + + GitHubModelsEmbeddingModelName(String modelName, Integer dimension) { + this.modelName = modelName; + this.dimension = dimension; + } + + public String modelName() { + return modelName; + } + + @Override 
import java.util.HashMap;
import java.util.Map;

/**
 * Embedding models hosted on GitHub Models, together with their known output dimensions.
 */
public enum GitHubModelsEmbeddingModelName {

    TEXT_EMBEDDING_3_SMALL("text-embedding-3-small", 1536),
    TEXT_EMBEDDING_3_LARGE("text-embedding-3-large", 3072),

    COHERE_EMBED_V3_ENGLISH("cohere-embed-v3-english", 1024),
    COHERE_EMBED_V3_MULTILINGUAL("cohere-embed-v3-multilingual", 1024);

    // Identifier sent to the API; also what toString() yields.
    private final String modelName;
    // Number of components in the embedding vector this model produces.
    private final Integer dimension;

    GitHubModelsEmbeddingModelName(String modelName, Integer dimension) {
        this.modelName = modelName;
        this.dimension = dimension;
    }

    /** Returns the exact model identifier used by the API. */
    public String modelName() {
        return modelName;
    }

    @Override
    public String toString() {
        return modelName;
    }

    /** Returns the embedding dimension of this model. */
    public Integer dimension() {
        return dimension;
    }

    // Lookup table from model identifier to dimension. Built once; sized so that no
    // rehash occurs (the previous capacity of values().length guaranteed one resize
    // with the default 0.75 load factor). Keyed by modelName() rather than toString()
    // to make the intent explicit.
    private static final Map<String, Integer> KNOWN_DIMENSION = new HashMap<>((int) (values().length / 0.75) + 1);

    static {
        for (GitHubModelsEmbeddingModelName embeddingModelName : values()) {
            KNOWN_DIMENSION.put(embeddingModelName.modelName(), embeddingModelName.dimension());
        }
    }

    /**
     * Returns the known dimension for the given model identifier, or {@code null}
     * when the model is not recognized.
     *
     * @param modelName the API model identifier, e.g. "text-embedding-3-small"
     */
    public static Integer knownDimension(String modelName) {
        return KNOWN_DIMENSION.get(modelName);
    }
}
java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static dev.langchain4j.data.message.AiMessage.aiMessage; +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.github.InternalGitHubModelHelper.*; +import static dev.langchain4j.spi.ServiceHelper.loadFactories; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + +/** + * Represents a language model, hosted on GitHub Models, that has a chat completion interface, such as gpt-4o. + *

+ * Mandatory parameters for initialization are: gitHubToken (the GitHub Token used for authentication) and modelName (the name of the model to use). + * You can also provide your own ChatCompletionsClient and ChatCompletionsAsyncClient instance, if you need more flexibility. + *

+ * The list of models, as well as the documentation and a playground to test them, can be found at https://github.com/marketplace/models + */ +public class GitHubModelsStreamingChatModel implements StreamingChatLanguageModel { + + private static final Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingChatModel.class); + + private ChatCompletionsAsyncClient client; + private final String modelName; + private final Integer maxTokens; + private final Double temperature; + private final Double topP; + private final List stop; + private final Double presencePenalty; + private final Double frequencyPenalty; + private final Long seed; + private final ChatCompletionsResponseFormat responseFormat; + private final List listeners; + + private GitHubModelsStreamingChatModel(ChatCompletionsAsyncClient client, + String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + List listeners) { + + this(modelName, maxTokens, temperature, topP, stop, presencePenalty, frequencyPenalty, seed, responseFormat, listeners); + this.client = client; + } + + private GitHubModelsStreamingChatModel(String endpoint, + ModelServiceVersion serviceVersion, + String gitHubToken, + String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + Duration timeout, + Integer maxRetries, + ProxyOptions proxyOptions, + boolean logRequestsAndResponses, + List listeners, + String userAgentSuffix, + Map customHeaders) { + + this(modelName, maxTokens, temperature, topP, stop, presencePenalty, frequencyPenalty, seed, responseFormat, listeners); + this.client = setupChatCompletionsBuilder(endpoint, serviceVersion, gitHubToken, timeout, maxRetries, proxyOptions, logRequestsAndResponses, userAgentSuffix, customHeaders) + 
.buildAsyncClient(); + } + + private GitHubModelsStreamingChatModel(String modelName, + Integer maxTokens, + Double temperature, + Double topP, + List stop, + Double presencePenalty, + Double frequencyPenalty, + Long seed, + ChatCompletionsResponseFormat responseFormat, + List listeners) { + + this.modelName = ensureNotBlank(modelName, "modelName"); + this.maxTokens = maxTokens; + this.temperature = temperature; + this.topP = topP; + this.stop = copyIfNotNull(stop); + this.presencePenalty = presencePenalty; + this.frequencyPenalty = frequencyPenalty; + this.seed = seed; + this.responseFormat = responseFormat; + this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + generate(messages, null, null, handler); + } + + @Override + public void generate(List messages, List toolSpecifications, StreamingResponseHandler handler) { + generate(messages, toolSpecifications, null, handler); + } + + @Override + public void generate(List messages, ToolSpecification toolSpecification, StreamingResponseHandler handler) { + generate(messages, null, toolSpecification, handler); + } + + private void generate(List messages, + List toolSpecifications, + ToolSpecification toolThatMustBeExecuted, + StreamingResponseHandler handler + ) { + ChatCompletionsOptions options = new ChatCompletionsOptions(toAzureAiMessages(messages)) + .setModel(modelName) + .setMaxTokens(maxTokens) + .setTemperature(temperature) + .setTopP(topP) + .setStop(stop) + .setPresencePenalty(presencePenalty) + .setFrequencyPenalty(frequencyPenalty) + .setSeed(seed) + .setResponseFormat(responseFormat); + + if (toolThatMustBeExecuted != null) { + options.setTools(toToolDefinitions(singletonList(toolThatMustBeExecuted))); + options.setToolChoice(toToolChoice(toolThatMustBeExecuted)); + } + if (!isNullOrEmpty(toolSpecifications)) { + options.setTools(toToolDefinitions(toolSpecifications)); + } + + 
GitHubModelsStreamingResponseBuilder responseBuilder = new GitHubModelsStreamingResponseBuilder(); + + ChatModelRequest modelListenerRequest = createModelListenerRequest(options, messages, toolSpecifications); + Map attributes = new ConcurrentHashMap<>(); + ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); + + listeners.forEach(listener -> { + try { + listener.onRequest(requestContext); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + }); + + asyncCall(handler, options, responseBuilder, requestContext); + } + + private void handleResponseException(Throwable throwable, StreamingResponseHandler handler) { + if (throwable instanceof HttpResponseException) { + HttpResponseException httpResponseException = (HttpResponseException) throwable; + logger.info("Error generating response, {}", httpResponseException.getValue()); + FinishReason exceptionFinishReason = contentFilterManagement(httpResponseException, "content_filter"); + Response response = Response.from( + aiMessage(httpResponseException.getMessage()), + null, + exceptionFinishReason + ); + handler.onComplete(response); + } else { + handler.onError(throwable); + } + } + + private void asyncCall(StreamingResponseHandler handler, ChatCompletionsOptions options, GitHubModelsStreamingResponseBuilder responseBuilder, ChatModelRequestContext requestContext) { + Flux chatCompletionsStream = client.completeStream(options); + + AtomicReference responseId = new AtomicReference<>(); + AtomicReference responseModel = new AtomicReference<>(); + + chatCompletionsStream.subscribe(chatCompletion -> { + responseBuilder.append(chatCompletion); + handle(chatCompletion, handler); + + if (isNotNullOrBlank(chatCompletion.getId())) { + responseId.set(chatCompletion.getId()); + } + if (!isNullOrBlank(chatCompletion.getModel())) { + responseModel.set(chatCompletion.getModel()); + } + }, + throwable -> { + Response response = 
responseBuilder.build(); + + ChatModelResponse modelListenerPartialResponse = createModelListenerResponse( + responseId.get(), + responseModel.get(), + response + ); + + ChatModelErrorContext errorContext = new ChatModelErrorContext( + throwable, + requestContext.request(), + modelListenerPartialResponse, + requestContext.attributes() + ); + + listeners.forEach(listener -> { + try { + listener.onError(errorContext); + } catch (Exception e2) { + logger.warn("Exception while calling model listener", e2); + } + }); + handleResponseException(throwable, handler); + }, + () -> { + Response response = responseBuilder.build(); + ChatModelResponse modelListenerResponse = createModelListenerResponse( + responseId.get(), + options.getModel(), + response + ); + ChatModelResponseContext responseContext = new ChatModelResponseContext( + modelListenerResponse, + requestContext.request(), + requestContext.attributes() + ); + listeners.forEach(listener -> { + try { + listener.onResponse(responseContext); + } catch (Exception e) { + logger.warn("Exception while calling model listener", e); + } + }); + handler.onComplete(response); + }); + } + + + private static void handle(StreamingChatCompletionsUpdate chatCompletions, + StreamingResponseHandler handler) { + + List choices = chatCompletions.getChoices(); + if (choices == null || choices.isEmpty()) { + return; + } + StreamingChatResponseMessageUpdate message = choices.get(0).getDelta(); + if (message != null && message.getContent() != null) { + handler.onNext(message.getContent()); + } + } + + public static Builder builder() { + for (GitHubModelsStreamingChatModelBuilderFactory factory : loadFactories(GitHubModelsStreamingChatModelBuilderFactory.class)) { + return factory.get(); + } + return new Builder(); + } + + public static class Builder { + + private String endpoint; + private ModelServiceVersion serviceVersion; + private String gitHubToken; + private String modelName; + private Integer maxTokens; + private Double temperature; 
+ private Double topP; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private Duration timeout; + Long seed; + ChatCompletionsResponseFormat responseFormat; + private Integer maxRetries; + private ProxyOptions proxyOptions; + private boolean logRequestsAndResponses; + private ChatCompletionsAsyncClient client; + private String userAgentSuffix; + private List listeners; + private Map customHeaders; + + /** + * Sets the GitHub Models endpoint. The default endpoint will be used if this isn't set. + * + * @param endpoint The GitHub Models endpoint in the format: https://models.inference.ai.azure.com + * @return builder + */ + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + /** + * Sets the Azure OpenAI API service version. If left blank, the latest service version will be used. + * + * @param serviceVersion The Azure OpenAI API service version in the format: 2023-05-15 + * @return builder + */ + public Builder serviceVersion(ModelServiceVersion serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + + /** + * Sets the GitHub token to access GitHub Models. + * + * @param gitHubToken The GitHub token. + * @return builder + */ + public Builder gitHubToken(String gitHubToken) { + this.gitHubToken = gitHubToken; + return this; + } + + /** + * Sets the model name in Azure OpenAI. This is a mandatory parameter. + * + * @param modelName The Model name. 
+ * @return builder + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Double topP) { + this.topP = topP; + return this; + } + + public Builder stop(List stop) { + this.stop = stop; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder seed(Long seed) { + this.seed = seed; + return this; + } + + public Builder responseFormat(ChatCompletionsResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder timeout(Duration timeout) { + this.timeout = timeout; + return this; + } + + public Builder maxRetries(Integer maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder proxyOptions(ProxyOptions proxyOptions) { + this.proxyOptions = proxyOptions; + return this; + } + + public Builder logRequestsAndResponses(boolean logRequestsAndResponses) { + this.logRequestsAndResponses = logRequestsAndResponses; + return this; + } + + public Builder chatCompletionsAsyncClient(ChatCompletionsAsyncClient client) { + this.client = client; + return this; + } + + public Builder userAgentSuffix(String userAgentSuffix) { + this.userAgentSuffix = userAgentSuffix; + return this; + } + + public Builder listeners(List listeners) { + this.listeners = listeners; + return this; + } + + public Builder customHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + + public GitHubModelsStreamingChatModel build() { + if (client != null) { + return new GitHubModelsStreamingChatModel( + client, + 
modelName, + maxTokens, + temperature, + topP, + stop, + presencePenalty, + frequencyPenalty, + seed, + responseFormat, + listeners + ); + } else { + return new GitHubModelsStreamingChatModel( + endpoint, + serviceVersion, + gitHubToken, + modelName, + maxTokens, + temperature, + topP, + stop, + presencePenalty, + frequencyPenalty, + seed, + responseFormat, + timeout, + maxRetries, + proxyOptions, + logRequestsAndResponses, + listeners, + userAgentSuffix, + customHeaders + ); + } + } + } +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsStreamingResponseBuilder.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsStreamingResponseBuilder.java new file mode 100644 index 00000000000..b233f332bf0 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/GitHubModelsStreamingResponseBuilder.java @@ -0,0 +1,138 @@ +package dev.langchain4j.model.github; + +import com.azure.ai.inference.models.*; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.model.github.InternalGitHubModelHelper.finishReasonFrom; +import static java.util.stream.Collectors.toList; + +/** + * This class needs to be thread safe because it is called when a streaming result comes back + * and there is no guarantee that this thread will be the same as the one that initiated the request, + * in fact it almost certainly won't be. 
+ */ +class GitHubModelsStreamingResponseBuilder { + + Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingResponseBuilder.class); + + private final StringBuffer contentBuilder = new StringBuffer(); + private final StringBuffer toolNameBuilder = new StringBuffer(); + private final StringBuffer toolArgumentsBuilder = new StringBuffer(); + private int inputTokenCount = 0; + private int outputTokenCount = 0; + private String toolExecutionsIndex = "call_undefined"; + private final Map toolExecutionRequestBuilderHashMap = new HashMap<>(); + private volatile CompletionsFinishReason azureFinishReason; + + public GitHubModelsStreamingResponseBuilder() { + } + + public void append(StreamingChatCompletionsUpdate streamingChatCompletionsUpdate) { + if (streamingChatCompletionsUpdate == null) { + return; + } + if (streamingChatCompletionsUpdate.getUsage() != null) { + inputTokenCount = streamingChatCompletionsUpdate.getUsage().getPromptTokens(); + outputTokenCount = streamingChatCompletionsUpdate.getUsage().getCompletionTokens(); + } + + List choices = streamingChatCompletionsUpdate.getChoices(); + if (choices == null || choices.isEmpty()) { + return; + } + + StreamingChatChoiceUpdate chatCompletionChoice = choices.get(0); + if (chatCompletionChoice == null) { + return; + } + + CompletionsFinishReason finishReason = chatCompletionChoice.getFinishReason(); + if (finishReason != null) { + this.azureFinishReason = finishReason; + } + + StreamingChatResponseMessageUpdate delta = chatCompletionChoice.getDelta(); + if (delta == null) { + return; + } + + String content = delta.getContent(); + if (content != null) { + contentBuilder.append(content); + return; + } + + if (delta.getToolCalls() != null && !delta.getToolCalls().isEmpty()) { + for (StreamingChatResponseToolCallUpdate toolCall : delta.getToolCalls()) { + ToolExecutionRequestBuilder toolExecutionRequestBuilder; + if (toolCall.getId() != null) { + toolExecutionsIndex = toolCall.getId(); + toolExecutionRequestBuilder 
= new ToolExecutionRequestBuilder(); + toolExecutionRequestBuilder.idBuilder.append(toolExecutionsIndex); + toolExecutionRequestBuilderHashMap.put(toolExecutionsIndex, toolExecutionRequestBuilder); + } else { + toolExecutionRequestBuilder = toolExecutionRequestBuilderHashMap.get(toolExecutionsIndex); + if (toolExecutionRequestBuilder == null) { + throw new IllegalStateException("Function without an id defined in the tool call"); + } + } + if (toolCall.getFunction().getName() != null) { + toolExecutionRequestBuilder.nameBuilder.append(toolCall.getFunction().getName()); + } + if (toolCall.getFunction().getArguments() != null) { + toolExecutionRequestBuilder.argumentsBuilder.append(toolCall.getFunction().getArguments()); + } + } + } + } + + public Response build() { + String content = contentBuilder.toString(); + TokenUsage tokenUsage = new TokenUsage(inputTokenCount, outputTokenCount); + FinishReason finishReason = finishReasonFrom(azureFinishReason); + + if (toolExecutionRequestBuilderHashMap.isEmpty()) { + return Response.from( + AiMessage.from(content), + tokenUsage, + finishReason + ); + } else { + List toolExecutionRequests = toolExecutionRequestBuilderHashMap.values().stream() + .map(it -> ToolExecutionRequest.builder() + .id(it.idBuilder.toString()) + .name(it.nameBuilder.toString()) + .arguments(it.argumentsBuilder.toString()) + .build()) + .collect(toList()); + + AiMessage aiMessage = isNullOrBlank(content) + ? 
AiMessage.from(toolExecutionRequests) + : AiMessage.from(content, toolExecutionRequests); + + return Response.from( + aiMessage, + tokenUsage, + finishReason + ); + } + } + + private static class ToolExecutionRequestBuilder { + + private final StringBuffer idBuilder = new StringBuffer(); + private final StringBuffer nameBuilder = new StringBuffer(); + private final StringBuffer argumentsBuilder = new StringBuffer(); + } +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/InternalGitHubModelHelper.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/InternalGitHubModelHelper.java new file mode 100644 index 00000000000..8d50ec49656 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/InternalGitHubModelHelper.java @@ -0,0 +1,366 @@ +package dev.langchain4j.model.github; + +import com.azure.ai.inference.ChatCompletionsClientBuilder; +import com.azure.ai.inference.EmbeddingsClientBuilder; +import com.azure.ai.inference.ModelServiceVersion; +import com.azure.ai.inference.models.*; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.azure.core.exception.HttpResponseException; +import com.azure.core.http.HttpClient; +import com.azure.core.http.ProxyOptions; +import com.azure.core.http.netty.NettyAsyncHttpClientProvider; +import com.azure.core.http.policy.ExponentialBackoffOptions; +import com.azure.core.http.policy.HttpLogDetailLevel; +import com.azure.core.http.policy.HttpLogOptions; +import com.azure.core.http.policy.RetryOptions; +import com.azure.core.util.BinaryData; +import com.azure.core.util.Header; +import com.azure.core.util.HttpClientOptions; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.*; +import dev.langchain4j.model.chat.listener.ChatModelRequest; +import 
dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.*; + +import static dev.langchain4j.data.message.AiMessage.aiMessage; +import static dev.langchain4j.internal.Utils.*; +import static dev.langchain4j.model.output.FinishReason.*; +import static java.time.Duration.ofSeconds; +import static java.util.stream.Collectors.toList; + +class InternalGitHubModelHelper { + + private static final Logger logger = LoggerFactory.getLogger(InternalGitHubModelHelper.class); + + public static final String DEFAULT_GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com"; + + public static final String DEFAULT_USER_AGENT = "langchain4j-github-models"; + + public static ChatCompletionsClientBuilder setupChatCompletionsBuilder(String endpoint, ModelServiceVersion serviceVersion, String gitHubToken, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean logRequestsAndResponses, String userAgentSuffix, Map customHeaders) { + HttpClientOptions clientOptions = getClientOptions(timeout, proxyOptions, userAgentSuffix, customHeaders); + ChatCompletionsClientBuilder chatCompletionsClientBuilder = new ChatCompletionsClientBuilder() + .endpoint(getEndpoint(endpoint)) + .serviceVersion(getModelServiceVersion(serviceVersion)) + .httpClient(getHttpClient(clientOptions)) + .clientOptions(clientOptions) + .httpLogOptions(getHttpLogOptions(logRequestsAndResponses)) + .retryOptions(getRetryOptions(maxRetries)) + .credential(getCredential(gitHubToken)); + + return chatCompletionsClientBuilder; + } + + public static EmbeddingsClientBuilder setupEmbeddingsBuilder(String endpoint, ModelServiceVersion serviceVersion, String gitHubToken, Duration timeout, Integer maxRetries, ProxyOptions proxyOptions, boolean 
logRequestsAndResponses, String userAgentSuffix, Map customHeaders) { + HttpClientOptions clientOptions = getClientOptions(timeout, proxyOptions, userAgentSuffix, customHeaders); + EmbeddingsClientBuilder embeddingsClientBuilder = new EmbeddingsClientBuilder() + .endpoint(getEndpoint(endpoint)) + .serviceVersion(getModelServiceVersion(serviceVersion)) + .httpClient(getHttpClient(clientOptions)) + .clientOptions(clientOptions) + .httpLogOptions(getHttpLogOptions(logRequestsAndResponses)) + .retryOptions(getRetryOptions(maxRetries)) + .credential(getCredential(gitHubToken)); + + return embeddingsClientBuilder; + } + + private static String getEndpoint(String endpoint) { + return isNullOrBlank(endpoint) ? DEFAULT_GITHUB_MODELS_ENDPOINT : endpoint; + } + + public static ModelServiceVersion getModelServiceVersion(ModelServiceVersion serviceVersion) { + return getOrDefault(serviceVersion, ModelServiceVersion.getLatest()); + } + + private static HttpClient getHttpClient(HttpClientOptions clientOptions) { + return new NettyAsyncHttpClientProvider().createInstance(clientOptions); + } + + private static HttpClientOptions getClientOptions(Duration timeout, ProxyOptions proxyOptions, String userAgentSuffix, Map customHeaders) { + timeout = getOrDefault(timeout, ofSeconds(60)); + HttpClientOptions clientOptions = new HttpClientOptions(); + clientOptions.setConnectTimeout(timeout); + clientOptions.setResponseTimeout(timeout); + clientOptions.setReadTimeout(timeout); + clientOptions.setWriteTimeout(timeout); + clientOptions.setProxyOptions(proxyOptions); + + String userAgent = DEFAULT_USER_AGENT; + if (userAgentSuffix!=null && !userAgentSuffix.isEmpty()) { + userAgent = DEFAULT_USER_AGENT + "-" + userAgentSuffix; + } + List

headers = new ArrayList<>(); + headers.add(new Header("User-Agent", userAgent)); + if (customHeaders != null) { + customHeaders.forEach((name, value) -> headers.add(new Header(name, value))); + } + clientOptions.setHeaders(headers); + return clientOptions; + } + + private static HttpLogOptions getHttpLogOptions(boolean logRequestsAndResponses) { + HttpLogOptions httpLogOptions = new HttpLogOptions(); + if (logRequestsAndResponses) { + httpLogOptions.setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS); + } + return httpLogOptions; + } + + private static RetryOptions getRetryOptions(Integer maxRetries) { + maxRetries = getOrDefault(maxRetries, 3); + ExponentialBackoffOptions exponentialBackoffOptions = new ExponentialBackoffOptions(); + exponentialBackoffOptions.setMaxRetries(maxRetries); + return new RetryOptions(exponentialBackoffOptions); + } + + private static KeyCredential getCredential(String gitHubToken) { + if (gitHubToken != null) { + return new AzureKeyCredential(gitHubToken); + } else { + throw new IllegalArgumentException("GitHub token is a mandatory parameter for connecting to GitHub models."); + } + } + + public static List toAzureAiMessages(List messages) { + + return messages.stream() + .map(InternalGitHubModelHelper::toAzureAiMessage) + .collect(toList()); + } + + public static ChatRequestMessage toAzureAiMessage(ChatMessage message) { + if (message instanceof AiMessage) { + AiMessage aiMessage = (AiMessage) message; + ChatRequestAssistantMessage chatRequestAssistantMessage = new ChatRequestAssistantMessage(getOrDefault(aiMessage.text(), "")); + chatRequestAssistantMessage.setToolCalls(toolExecutionRequestsFrom(message)); + return chatRequestAssistantMessage; + } else if (message instanceof ToolExecutionResultMessage) { + ToolExecutionResultMessage toolExecutionResultMessage = (ToolExecutionResultMessage) message; + return new ChatRequestToolMessage(toolExecutionResultMessage.text(), toolExecutionResultMessage.id()); + } else if (message instanceof 
SystemMessage) { + SystemMessage systemMessage = (SystemMessage) message; + return new ChatRequestSystemMessage(systemMessage.text()); + } else if (message instanceof UserMessage) { + UserMessage userMessage = (UserMessage) message; + ChatRequestUserMessage chatRequestUserMessage; + if (userMessage.hasSingleText()) { + chatRequestUserMessage = new ChatRequestUserMessage(userMessage.singleText()); + } else { + chatRequestUserMessage = ChatRequestUserMessage.fromContentItems(userMessage.contents().stream() + .map(content -> { + if (content instanceof TextContent) { + String text = ((TextContent) content).text(); + return new ChatMessageTextContentItem(text); + } else if (content instanceof ImageContent) { + ImageContent imageContent = (ImageContent) content; + if (imageContent.image().url() == null) { + throw new IllegalArgumentException("Image URL is not present. Base64 encoded images are not supported at the moment."); + } + ChatMessageImageUrl imageUrl = new ChatMessageImageUrl(imageContent.image().url().toString()); + imageUrl.setDetail(ChatMessageImageDetailLevel.fromString(imageContent.detailLevel().name())); + return new ChatMessageImageContentItem(imageUrl); + } else { + throw new IllegalArgumentException("Unsupported content type: " + content.type()); + } + }) + .collect(toList())); + } + return chatRequestUserMessage; + } else { + throw new IllegalArgumentException("Unsupported message type: " + message.type()); + } + } + + private static List toolExecutionRequestsFrom(ChatMessage message) { + if (message instanceof AiMessage) { + AiMessage aiMessage = (AiMessage) message; + if (aiMessage.hasToolExecutionRequests()) { + return aiMessage.toolExecutionRequests().stream() + .map(toolExecutionRequest -> new ChatCompletionsFunctionToolCall(toolExecutionRequest.id(), new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments()))) + .collect(toList()); + + } + } + return null; + } + + public static List toToolDefinitions(Collection 
toolSpecifications) { + return toolSpecifications.stream() + .map(InternalGitHubModelHelper::toToolDefinition) + .collect(toList()); + } + + private static ChatCompletionsToolDefinition toToolDefinition(ToolSpecification toolSpecification) { + FunctionDefinition functionDefinition = new FunctionDefinition(toolSpecification.name()); + functionDefinition.setDescription(toolSpecification.description()); + functionDefinition.setParameters(toAzureAiParameters(toolSpecification.parameters())); + return new ChatCompletionsFunctionToolDefinition(functionDefinition); + } + + public static BinaryData toToolChoice(ToolSpecification toolThatMustBeExecuted) { + FunctionCall functionCall = new FunctionCall(toolThatMustBeExecuted.name(), toAzureAiParameters(toolThatMustBeExecuted.parameters()).toString()); + ChatCompletionsToolCall toolToCall = new ChatCompletionsFunctionToolCall(toolThatMustBeExecuted.name(), functionCall); + return BinaryData.fromObject(toolToCall); + } + + private static final Map NO_PARAMETER_DATA = new HashMap<>(); + + static { + NO_PARAMETER_DATA.put("type", "object"); + NO_PARAMETER_DATA.put("properties", new HashMap<>()); + } + + private static BinaryData toAzureAiParameters(ToolParameters toolParameters) { + Parameters parameters = new Parameters(); + if (toolParameters == null) { + return BinaryData.fromObject(NO_PARAMETER_DATA); + } + parameters.setProperties(toolParameters.properties()); + parameters.setRequired(toolParameters.required()); + return BinaryData.fromObject(parameters); + } + + private static class Parameters { + + private final String type = "object"; + + private Map> properties = new HashMap<>(); + + private List required = new ArrayList<>(); + + public String getType() { + return this.type; + } + + public Map> getProperties() { + return properties; + } + + public void setProperties(Map> properties) { + this.properties = properties; + } + + public List getRequired() { + return required; + } + + public void setRequired(List required) { + 
this.required = required; + } + } + + public static AiMessage aiMessageFrom(ChatResponseMessage chatResponseMessage) { + String text = chatResponseMessage.getContent(); + + if (isNullOrEmpty(chatResponseMessage.getToolCalls())) { + return aiMessage(text); + } else { + List toolExecutionRequests = chatResponseMessage.getToolCalls() + .stream() + .map(chatCompletionsFunctionToolCall -> + ToolExecutionRequest.builder() + .id(chatCompletionsFunctionToolCall.getId()) + .name(chatCompletionsFunctionToolCall.getFunction().getName()) + .arguments(chatCompletionsFunctionToolCall.getFunction().getArguments()) + .build()) + .collect(toList()); + + return isNullOrBlank(text) ? + aiMessage(toolExecutionRequests) : + aiMessage(text, toolExecutionRequests); + } + } + + public static TokenUsage tokenUsageFrom(CompletionsUsage azureAiUsage) { + if (azureAiUsage == null) { + return null; + } + return new TokenUsage( + azureAiUsage.getPromptTokens(), + azureAiUsage.getCompletionTokens(), + azureAiUsage.getTotalTokens() + ); + } + + public static FinishReason finishReasonFrom(CompletionsFinishReason azureAiFinishReason) { + if (azureAiFinishReason == null) { + return null; + } else if (azureAiFinishReason == CompletionsFinishReason.STOPPED) { + return STOP; + } else if (azureAiFinishReason == CompletionsFinishReason.TOKEN_LIMIT_REACHED) { + return LENGTH; + } else if (azureAiFinishReason == CompletionsFinishReason.CONTENT_FILTERED) { + return CONTENT_FILTER; + } else if (azureAiFinishReason == CompletionsFinishReason.TOOL_CALLS) { + return TOOL_EXECUTION; + } else { + return null; + } + } + + /** + * Support for Responsible AI (content filtered by Azure OpenAI for violence, self harm, or hate). 
+ */ + public static FinishReason contentFilterManagement(HttpResponseException httpResponseException, String contentFilterCode) { + FinishReason exceptionFinishReason = FinishReason.OTHER; + if (httpResponseException.getValue() instanceof Map) { + try { + Map error = (Map) httpResponseException.getValue(); + Object errorMap = error.get("error"); + if (errorMap instanceof Map) { + Map errorDetails = (Map) errorMap; + Object errorCode = errorDetails.get("code"); + if (errorCode instanceof String) { + String code = (String) errorCode; + if (contentFilterCode.equals(code)) { + // The content was filtered by Azure OpenAI's content filter (for violence, self harm, or hate). + exceptionFinishReason = FinishReason.CONTENT_FILTER; + } + } + } + } catch (ClassCastException classCastException) { + logger.error("Error parsing error response from Azure OpenAI", classCastException); + } + } + return exceptionFinishReason; + } + + static ChatModelRequest createModelListenerRequest(ChatCompletionsOptions options, + List messages, + List toolSpecifications) { + return ChatModelRequest.builder() + .model(options.getModel()) + .temperature(options.getTemperature()) + .topP(options.getTopP()) + .maxTokens(options.getMaxTokens()) + .messages(messages) + .toolSpecifications(toolSpecifications) + .build(); + } + + static ChatModelResponse createModelListenerResponse(String responseId, + String responseModel, + Response response) { + if (response == null) { + return null; + } + + return ChatModelResponse.builder() + .id(responseId) + .model(responseModel) + .tokenUsage(response.tokenUsage()) + .finishReason(response.finishReason()) + .aiMessage(response.content()) + .build(); + } +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsChatModelBuilderFactory.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsChatModelBuilderFactory.java new file mode 100644 index 00000000000..bf551a36067 --- /dev/null 
+++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsChatModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.github.spi; + +import dev.langchain4j.model.github.GitHubModelsChatModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link GitHubModelsChatModel.Builder} instances. + */ +public interface GitHubModelsChatModelBuilderFactory extends Supplier { +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsEmbeddingModelBuilderFactory.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsEmbeddingModelBuilderFactory.java new file mode 100644 index 00000000000..cf637876a49 --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsEmbeddingModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.github.spi; + +import dev.langchain4j.model.github.GitHubModelsEmbeddingModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link GitHubModelsEmbeddingModel.Builder} instances. + */ +public interface GitHubModelsEmbeddingModelBuilderFactory extends Supplier { +} diff --git a/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsStreamingChatModelBuilderFactory.java b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsStreamingChatModelBuilderFactory.java new file mode 100644 index 00000000000..5d277c9c66b --- /dev/null +++ b/langchain4j-github-models/src/main/java/dev/langchain4j/model/github/spi/GitHubModelsStreamingChatModelBuilderFactory.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.github.spi; + +import dev.langchain4j.model.github.GitHubModelsStreamingChatModel; + +import java.util.function.Supplier; + +/** + * A factory for building {@link GitHubModelsStreamingChatModel.Builder} instances. 
+ */ +public interface GitHubModelsStreamingChatModelBuilderFactory extends Supplier { +} diff --git a/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelIT.java b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelIT.java new file mode 100644 index 00000000000..af1bce4336b --- /dev/null +++ b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelIT.java @@ -0,0 +1,377 @@ +package dev.langchain4j.model.github; + +import com.azure.ai.inference.models.ChatCompletionsResponseFormatJson; +import com.azure.core.util.BinaryData; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolParameters; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.*; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.*; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.github.GitHubModelsChatModelName.GPT_4_O_MINI; +import static dev.langchain4j.model.output.FinishReason.LENGTH; +import static dev.langchain4j.model.output.FinishReason.STOP; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; + +public class 
GitHubModelsChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(GitHubModelsChatModelIT.class); + + @Test + void should_generate_answer_and_finish_reason_stop() { + + GitHubModelsChatModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(GPT_4_O_MINI.modelName()) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("What is the capital of France?"); + Response response = model.generate(userMessage); + logger.info("Response: {}", response.content().text()); + assertThat(response.content().text()).contains("Paris"); + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_generate_answer_and_return_token_usage_and_finish_reason_stop(String modelName) { + + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("hello, how are you?"); + + Response response = model.generate(userMessage); + + assertThat(response.content().text()).isNotBlank(); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(13); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(1); + assertThat(tokenUsage.totalTokenCount()).isGreaterThan(14); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_generate_answer_and_return_token_usage_and_finish_reason_length(String modelName) { + + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .maxTokens(3) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("hello, how are you?"); + + Response response = model.generate(userMessage); 
+ + assertThat(response.content().text()).isNotBlank(); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(13); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(3); + assertThat(tokenUsage.totalTokenCount()).isEqualTo(16); + + assertThat(response.finishReason()).isEqualTo(LENGTH); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_call_function_with_argument(String modelName) { + + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("What should I wear in Paris, France depending on the weather?"); + + // This test will use the function called "getCurrentWeather" which is defined below. + String toolName = "getCurrentWeather"; + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name(toolName) + .description("Get the current weather") + .parameters(getToolParameters()) + .build(); + + Response response = model.generate(Collections.singletonList(userMessage), toolSpecification); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + assertThat(response.finishReason()).isEqualTo(STOP); + + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + + // We can now call the function with the correct parameters. 
+ WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class); + int currentWeather = getCurrentWeather(weatherLocation); + + String weather = String.format("The weather in %s is %d degrees %s.", + weatherLocation.getLocation(), currentWeather, weatherLocation.getUnit()); + + assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius."); + + // Now that we know the function's result, we can call the model again with the result as input. + ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, weather); + SystemMessage systemMessage = SystemMessage.systemMessage("If the weather is above 30 degrees celsius, recommend the user wears a t-shirt and shorts."); + + List chatMessages = new ArrayList<>(); + chatMessages.add(systemMessage); + chatMessages.add(userMessage); + chatMessages.add(aiMessage); + chatMessages.add(toolExecutionResultMessage); + + Response response2 = model.generate(chatMessages); + + assertThat(response2.content().text()).isNotBlank(); + assertThat(response2.content().text()).contains("t-shirt"); + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_call_function_with_no_argument(String modelName) { + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("What time is it?"); + + // This test will use the function called "getCurrentDateAndTime" which takes no arguments + String toolName = "getCurrentDateAndTime"; + + ToolSpecification noArgToolSpec = ToolSpecification.builder() + .name(toolName) + .description("Get the current date and time") + .build(); + + Response response = model.generate(Collections.singletonList(userMessage), noArgToolSpec); + + AiMessage 
aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).isEqualTo("{}"); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_call_three_functions_in_parallel(String modelName) throws Exception { + + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight."); + + List toolSpecifications = asList( + ToolSpecification.builder() + .name("sum") + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(), + ToolSpecification.builder() + .name("square") + .description("returns the square of one number") + .addParameter("number", INTEGER) + .build(), + ToolSpecification.builder() + .name("cube") + .description("returns the cube of one number") + .addParameter("number", INTEGER) + .build() + ); + + Response response = model.generate(Collections.singletonList(userMessage), toolSpecifications); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(aiMessage); + assertThat(aiMessage.toolExecutionRequests()).hasSize(3); + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + assertThat(toolExecutionRequest.name()).isNotEmpty(); + ToolExecutionResultMessage toolExecutionResultMessage; + if (toolExecutionRequest.name().equals("sum")) { + 
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4"); + } else if (toolExecutionRequest.name().equals("square")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16"); + } else if (toolExecutionRequest.name().equals("cube")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512"); + } else { + throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name()); + } + messages.add(toolExecutionResultMessage); + } + + Response response2 = model.generate(messages); + AiMessage aiMessage2 = response2.content(); + + // then + assertThat(aiMessage2.text()).contains("4", "16", "512"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); + + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_use_json_format(String modelName) { + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .responseFormat(new ChatCompletionsResponseFormatJson()) + .logRequestsAndResponses(true) + .build(); + + SystemMessage systemMessage = SystemMessage.systemMessage("You are a helpful assistant designed to output JSON."); + UserMessage userMessage = userMessage("List teams in the past French presidents, with their first 
name, last name, dates of service."); + + Response response = model.generate(systemMessage, userMessage); + + assertThat(response.content().text()).contains("Chirac", "Sarkozy", "Hollande", "Macron"); + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(GitHubModelsChatModelName.class) + void should_support_all_string_model_names(GitHubModelsChatModelName modelName) { + + // given + String modelNameString = modelName.toString(); + + ChatLanguageModel model = GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelNameString) + .maxTokens(1) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Hi"); + + // when + Response response = model.generate(userMessage); + + // then + assertThat(response.content().text()).isNotBlank(); + + assertThat(response.tokenUsage()).isNotNull(); + assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()).isGreaterThan(0); + } + + private static ToolParameters getToolParameters() { + Map> properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The city and state, e.g. San Francisco, CA"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("enum", Arrays.asList("celsius", "fahrenheit")); + properties.put("unit", unit); + + List required = Arrays.asList("location", "unit"); + + return ToolParameters.builder() + .properties(properties) + .required(required) + .build(); + } + + // This is the method we offer to the LLM to be used as a function_call. + // For this example, we ignore the input parameter and return a simple value. 
+ private static int getCurrentWeather(WeatherLocation weatherLocation) { + return 35; + } + + // WeatherLocation is used for this sample. This describes the parameter of the function you want to use. + private static class WeatherLocation { + @JsonProperty(value = "unit") + String unit; + @JsonProperty(value = "location") + String location; + + @JsonCreator + WeatherLocation(@JsonProperty(value = "unit") String unit, @JsonProperty(value = "location") String location) { + this.unit = unit; + this.location = location; + } + + public String getUnit() { + return unit; + } + + public String getLocation() { + return location; + } + } + + @AfterEach + void afterEach() throws InterruptedException { + String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GITHUB_MODELS"); + if (ciDelaySeconds != null) { + Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L); + } + } +} diff --git a/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelListenerIT.java b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelListenerIT.java new file mode 100644 index 00000000000..f98ad41a836 --- /dev/null +++ b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsChatModelListenerIT.java @@ -0,0 +1,45 @@ +package dev.langchain4j.model.github; + +import com.azure.core.exception.ClientAuthenticationException; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.ChatModelListenerIT; +import dev.langchain4j.model.chat.listener.ChatModelListener; + +import static java.util.Collections.singletonList; + +public class GitHubModelsChatModelListenerIT extends ChatModelListenerIT { + + @Override + protected ChatLanguageModel createModel(ChatModelListener listener) { + return GitHubModelsChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName()) + .temperature(temperature()) + .topP(topP()) + .maxTokens(maxTokens()) + 
.logRequestsAndResponses(true) + .listeners(singletonList(listener)) + .build(); + } + + @Override + protected String modelName() { + return "gpt-4o-mini"; + } + + @Override + protected ChatLanguageModel createFailingModel(ChatModelListener listener) { + return GitHubModelsChatModel.builder() + .gitHubToken("banana") + .modelName(modelName()) + .maxRetries(1) + .logRequestsAndResponses(true) + .listeners(singletonList(listener)) + .build(); + } + + @Override + protected Class expectedExceptionClass() { + return ClientAuthenticationException.class; + } +} diff --git a/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelIT.java b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelIT.java new file mode 100644 index 00000000000..4088bc85332 --- /dev/null +++ b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsEmbeddingModelIT.java @@ -0,0 +1,108 @@ +package dev.langchain4j.model.github; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.ArrayList; +import java.util.List; + +import static dev.langchain4j.model.github.GitHubModelsEmbeddingModelName.TEXT_EMBEDDING_3_SMALL; +import static org.assertj.core.api.Assertions.assertThat; + +class GitHubModelsEmbeddingModelIT { + + EmbeddingModel model = GitHubModelsEmbeddingModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(TEXT_EMBEDDING_3_SMALL.modelName()) + .logRequestsAndResponses(true) + .build(); + + @Test + void should_embed_and_return_token_usage() { + + Response response = model.embed("hello world"); + + 
assertThat(response.content().vector()).hasSize(1536); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(2); + assertThat(tokenUsage.outputTokenCount()).isNull(); + assertThat(tokenUsage.totalTokenCount()).isEqualTo(2); + + assertThat(response.finishReason()).isNull(); + } + + @Test + void should_embed_in_batches() { + + int batchSize = 16; + int numberOfSegments = batchSize + 1; + + List segments = new ArrayList<>(); + for (int i = 0; i < numberOfSegments; i++) { + segments.add(TextSegment.from("text " + i)); + } + + Response> response = model.embedAll(segments); + + assertThat(response.content()).hasSize(numberOfSegments); + assertThat(response.content().get(0).dimension()).isEqualTo(1536); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(numberOfSegments * 3); + assertThat(tokenUsage.outputTokenCount()).isNull(); + assertThat(tokenUsage.totalTokenCount()).isEqualTo(numberOfSegments * 3); + + assertThat(response.finishReason()).isNull(); + } + + @ParameterizedTest(name = "Testing model {0}") + @EnumSource(value = GitHubModelsEmbeddingModelName.class) + void should_support_all_string_model_names(GitHubModelsEmbeddingModelName modelName) { + + // given + EmbeddingModel model = GitHubModelsEmbeddingModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName.modelName()) + .dimensions(modelName.dimension()) + .logRequestsAndResponses(true) + .build(); + + // when + Response response = model.embed("hello world"); + + // then + assertThat(response.content().vector()).isNotEmpty(); + } + + @Test + void should_embed_text_with_embedding_shortening() { + + // given + int dimensions = 100; + + EmbeddingModel model = GitHubModelsEmbeddingModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(TEXT_EMBEDDING_3_SMALL.modelName()) + .dimensions(dimensions) + .logRequestsAndResponses(true) + .build(); + + // when + Response 
response = model.embed("hello world"); + + // then + assertThat(response.content().dimension()).isEqualTo(dimensions); + } + + @Test + void should_return_correct_dimension() { + assertThat(model.dimension()).isEqualTo(1536); + } +} diff --git a/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelIT.java b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelIT.java new file mode 100644 index 00000000000..11feffe0a63 --- /dev/null +++ b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelIT.java @@ -0,0 +1,367 @@ +package dev.langchain4j.model.github; + +import com.azure.ai.inference.models.ChatCompletionsResponseFormatJson; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static 
dev.langchain4j.model.output.FinishReason.STOP; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +class GitHubModelsStreamingChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(GitHubModelsStreamingChatModelIT.class); + + public long STREAMING_TIMEOUT = 120; + + @ParameterizedTest(name = "Model name {0} with async client set to {1}") + @CsvSource({ + "Phi-3.5-mini-instruct, true", + "Phi-3.5-mini-instruct, false" + }) + void should_stream_answer(String modelName, boolean useAsyncClient) throws Exception { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + model.generate("What is the capital of France?", new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + answerBuilder.append(token); + } + + @Override + public void onComplete(Response response) { + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String answer = futureAnswer.get(STREAMING_TIMEOUT, SECONDS); + Response response = futureResponse.get(STREAMING_TIMEOUT, SECONDS); + + assertThat(answer).contains("Paris"); + assertThat(response.content().text()).isEqualTo(answer); + + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(10); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + 
assertThat(response.tokenUsage().totalTokenCount()) + .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "Mistral-nemo", + "meta-llama-3-8b-instruct" + }) + void test_different_available_models(String modelName) throws Exception { + + StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate("What is the capital of France?", handler); + Response response = handler.get(); + + // then + assertThat(response.content().text()).contains("Paris"); + + assertThat(response.tokenUsage().inputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().outputTokenCount()).isGreaterThan(0); + assertThat(response.tokenUsage().totalTokenCount()) + .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @ValueSource(strings = {"gpt-4o"}) + void should_use_json_format(String modelName) { + + StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .responseFormat(new ChatCompletionsResponseFormatJson()) + .logRequestsAndResponses(true) + .build(); + + String userMessage = "Return JSON with two fields: name and surname of Klaus Heisler."; + + String expectedJson = "{\"name\": \"Klaus\", \"surname\": \"Heisler\"}"; + + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(userMessage, handler); + Response response = handler.get(); + + 
assertThat(response.content().text()).isEqualToIgnoringWhitespace(expectedJson); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_call_function_with_argument(String modelName) throws Exception { + + CompletableFuture> futureResponse = new CompletableFuture<>(); + + StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + UserMessage userMessage = userMessage("Two plus two?"); + + String toolName = "calculator"; + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name(toolName) + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(); + + model.generate(singletonList(userMessage), toolSpecification, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + Exception e = new IllegalStateException("onNext() should never be called when tool is executed"); + futureResponse.completeExceptionally(e); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } + }); + + Response response = futureResponse.get(STREAMING_TIMEOUT, SECONDS); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.name()).isEqualTo(toolName); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + + // Token usage should in fact be > 0, but this is currently unsupported on the server side + assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(0); + 
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); + assertThat(response.tokenUsage().totalTokenCount()) + .isEqualTo(response.tokenUsage().inputTokenCount() + response.tokenUsage().outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + + ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "four"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + CompletableFuture> futureResponse2 = new CompletableFuture<>(); + + model.generate(messages, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + } + + @Override + public void onComplete(Response response) { + futureResponse2.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse2.completeExceptionally(error); + } + }); + + Response response2 = futureResponse2.get(STREAMING_TIMEOUT, SECONDS); + AiMessage aiMessage2 = response2.content(); + + // then + assertThat(aiMessage2.text()).contains("four"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + // Token usage should in fact be > 0, but this is currently unsupported on the server side + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isEqualTo(0); + assertThat(tokenUsage2.outputTokenCount()).isEqualTo(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); + + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest(name = "Model name {0}") + @CsvSource({ + "gpt-4o" + }) + void should_call_three_functions_in_parallel(String modelName) throws Exception { + + CompletableFuture> futureResponse = new CompletableFuture<>(); + + StreamingChatLanguageModel model = GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName) + .logRequestsAndResponses(true) + .build(); + + 
UserMessage userMessage = userMessage("Give three numbers, ordered by size: the sum of two plus two, the square of four, and finally the cube of eight."); + + List toolSpecifications = asList( + ToolSpecification.builder() + .name("sum") + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(), + ToolSpecification.builder() + .name("square") + .description("returns the square of one number") + .addParameter("number", INTEGER) + .build(), + ToolSpecification.builder() + .name("cube") + .description("returns the cube of one number") + .addParameter("number", INTEGER) + .build() + ); + + model.generate(singletonList(userMessage), toolSpecifications, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + Exception e = new IllegalStateException("onNext() should never be called when tool is executed"); + futureResponse.completeExceptionally(e); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse.completeExceptionally(error); + } + }); + + Response response = futureResponse.get(STREAMING_TIMEOUT, SECONDS); + + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).isNull(); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(aiMessage); + assertThat(aiMessage.toolExecutionRequests()).hasSize(3); + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + assertThat(toolExecutionRequest.name()).isNotEmpty(); + ToolExecutionResultMessage toolExecutionResultMessage; + if (toolExecutionRequest.name().equals("sum")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "4"); + } else if (toolExecutionRequest.name().equals("square")) { + 
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 4}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "16"); + } else if (toolExecutionRequest.name().equals("cube")) { + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"number\": 8}"); + toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, "512"); + } else { + throw new AssertionError("Unexpected tool name: " + toolExecutionRequest.name()); + } + messages.add(toolExecutionResultMessage); + } + CompletableFuture> futureResponse2 = new CompletableFuture<>(); + + model.generate(messages, new StreamingResponseHandler() { + + @Override + public void onNext(String token) { + } + + @Override + public void onComplete(Response response) { + futureResponse2.complete(response); + } + + @Override + public void onError(Throwable error) { + futureResponse2.completeExceptionally(error); + } + }); + + Response response2 = futureResponse2.get(STREAMING_TIMEOUT, SECONDS); + AiMessage aiMessage2 = response2.content(); + + // then + assertThat(aiMessage2.text()).contains("4", "16", "512"); + assertThat(aiMessage2.toolExecutionRequests()).isNull(); + + // Token usage should in fact be > 0, but this is currently unsupported on the server side + TokenUsage tokenUsage2 = response2.tokenUsage(); + assertThat(tokenUsage2.inputTokenCount()).isEqualTo(0); + assertThat(tokenUsage2.outputTokenCount()).isEqualTo(0); + assertThat(tokenUsage2.totalTokenCount()) + .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount()); + + assertThat(response2.finishReason()).isEqualTo(STOP); + } + + @AfterEach + void afterEach() throws InterruptedException { + String ciDelaySeconds = System.getenv("CI_DELAY_SECONDS_GITHUB_MODELS"); + if (ciDelaySeconds != null) { + Thread.sleep(Integer.parseInt(ciDelaySeconds) * 1000L); + } + } +} diff --git 
a/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelListenerIT.java b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelListenerIT.java new file mode 100644 index 00000000000..cf5050aae2f --- /dev/null +++ b/langchain4j-github-models/src/test/java/dev/langchain4j/model/github/GitHubModelsStreamingChatModelListenerIT.java @@ -0,0 +1,56 @@ +package dev.langchain4j.model.github; + +import com.azure.core.exception.ClientAuthenticationException; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatModelListenerIT; +import dev.langchain4j.model.chat.listener.ChatModelListener; +import org.junit.jupiter.api.Disabled; + +import static java.util.Collections.singletonList; + +class GitHubModelsStreamingChatModelListenerIT extends StreamingChatModelListenerIT { + + @Override + protected StreamingChatLanguageModel createModel(ChatModelListener listener) { + return GitHubModelsStreamingChatModel.builder() + .gitHubToken(System.getenv("GITHUB_TOKEN")) + .modelName(modelName()) + .temperature(temperature()) + .topP(topP()) + .maxTokens(maxTokens()) + .logRequestsAndResponses(true) + .listeners(singletonList(listener)) + .build(); + } + + @Override + protected String modelName() { + return "gpt-4o-mini"; + } + + @Override + protected StreamingChatLanguageModel createFailingModel(ChatModelListener listener) { + return GitHubModelsStreamingChatModel.builder() + .gitHubToken("banana") + .modelName(modelName()) + .maxRetries(1) + .logRequestsAndResponses(true) + .listeners(singletonList(listener)) + .build(); + } + + @Override + protected Class expectedExceptionClass() { + return ClientAuthenticationException.class; + } + + @Override + protected boolean assertTokenUsage() { + return false; + } + + @Override + @Disabled("GitHubModelsStreamingChatModel implementation is incorrect") + protected void should_listen_error() { + } +} diff 
--git a/langchain4j-github-models/src/test/resources/log4j2.xml b/langchain4j-github-models/src/test/resources/log4j2.xml new file mode 100644 index 00000000000..43160a2fcc8 --- /dev/null +++ b/langchain4j-github-models/src/test/resources/log4j2.xml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + + diff --git a/pom.xml b/pom.xml index 9cff294603a..2a2909cc284 100644 --- a/pom.xml +++ b/pom.xml @@ -41,6 +41,7 @@ langchain4j-ovh-ai langchain4j-open-ai langchain4j-qianfan + langchain4j-github-models langchain4j-google-ai-gemini langchain4j-vertex-ai langchain4j-vertex-ai-gemini