From 21eb8b962fc226e7f75232d34741715a26d3b7a8 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Wed, 25 Sep 2024 02:55:25 -0400 Subject: [PATCH] Jlama revision bump, add working Q type to builder (#1825) ## Change Bump Jlama revision to 0.5.0. This revision is currently being promoted in Maven, so it should be available within the next 6 hours. ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [X] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [X] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green --- langchain4j-jlama/README.md | 2 +- langchain4j-jlama/pom.xml | 4 ++-- .../langchain4j/model/jlama/JlamaChatModel.java | 5 +++++ .../model/jlama/JlamaEmbeddingModel.java | 9 +++++++-- .../model/jlama/JlamaLanguageModel.java | 5 +++++ .../dev/langchain4j/model/jlama/JlamaModel.java | 17 ++++++++++++++--- .../model/jlama/JlamaModelRegistry.java | 2 +- .../model/jlama/JlamaStreamingChatModel.java | 5 +++++ .../jlama/JlamaStreamingLanguageModel.java | 5 +++++ .../model/jlama/JlamaChatModelIT.java | 2 +- .../model/jlama/JlamaLanguageModelIT.java | 2 +- .../model/jlama/JlamaStreamingChatModelIT.java | 2 +- .../jlama/JlamaStreamingLanguageModelIT.java | 2 +- 13 files changed, 49 insertions(+), 13 deletions(-) diff --git a/langchain4j-jlama/README.md b/langchain4j-jlama/README.md index 8909b32f919..9c4eebd59a9 100644 --- a/langchain4j-jlama/README.md +++ b/langchain4j-jlama/README.md @@ -3,7 +3,7 @@ [Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java applications. -Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference. 
+Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference. Jlama uses huggingface models in safetensor format. Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`. diff --git a/langchain4j-jlama/pom.xml b/langchain4j-jlama/pom.xml index c55a8aee2d9..9ea8e7ac79a 100644 --- a/langchain4j-jlama/pom.xml +++ b/langchain4j-jlama/pom.xml @@ -5,7 +5,7 @@ 4.0.0 langchain4j-jlama LangChain4j :: Integration :: Jlama - Jlama: Pure Java LLM Inference Engine - Requires Java 21 + Jlama: LLM Inference Engine for Java - Requires Java 20+ dev.langchain4j @@ -15,7 +15,7 @@ - 0.3.1 + 0.5.0 2.16.1 2.40.0 21 diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java index 54276dc851b..666a4cb9068 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaChatModel.java @@ -2,6 +2,7 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.functions.Generator; +import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.prompt.*; import com.github.tjake.jlama.util.JsonSupport; import dev.langchain4j.agent.tool.ToolExecutionRequest; @@ -33,6 +34,7 @@ public JlamaChatModel(Path modelCachePath, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, + DType workingQuantizedType, Float temperature, Integer maxTokens) { JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath); @@ -42,6 +44,9 @@ public JlamaChatModel(Path modelCachePath, if (quantizeModelAtRuntime != null && quantizeModelAtRuntime) loader = loader.quantized(); + if (workingQuantizedType != null) + loader = loader.workingQuantizationType(workingQuantizedType); + if (threadCount != null) loader = 
loader.threadCount(threadCount); diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java index cf71198edb1..94fea97aa0e 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaEmbeddingModel.java @@ -3,6 +3,7 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.ModelSupport; import com.github.tjake.jlama.model.bert.BertModel; +import com.github.tjake.jlama.model.functions.Generator; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.internal.RetryUtils; @@ -20,6 +21,7 @@ public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel { private final BertModel model; + private final Generator.PoolingType poolingType; @Builder public JlamaEmbeddingModel(Path modelCachePath, @@ -27,6 +29,7 @@ public JlamaEmbeddingModel(Path modelCachePath, String authToken, Integer threadCount, Boolean quantizeModelAtRuntime, + Generator.PoolingType poolingType, Path workingDirectory) { JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath); @@ -46,10 +49,12 @@ public JlamaEmbeddingModel(Path modelCachePath, if (workingDirectory != null) loader = loader.workingDirectory(workingDirectory); - loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS); + loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING); this.model = (BertModel) loader.load(); this.dimension = model.getConfig().embeddingLength; + + this.poolingType = poolingType == null ? 
Generator.PoolingType.MODEL : poolingType; } public static JlamaEmbeddingModelBuilder builder() { @@ -64,7 +69,7 @@ public Response> embedAll(List textSegments) { List embeddings = new ArrayList<>(); textSegments.forEach(textSegment -> { - embeddings.add(Embedding.from(model.embed(textSegment.text()))); + embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType))); }); return Response.from(embeddings); diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java index 1bf0d998e6e..8d81b3f8334 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaLanguageModel.java @@ -2,6 +2,7 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.functions.Generator; +import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.prompt.PromptContext; import dev.langchain4j.internal.RetryUtils; import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory; @@ -30,6 +31,7 @@ public JlamaLanguageModel(Path modelCachePath, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, + DType workingQuantizedType, Float temperature, Integer maxTokens) { JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath); @@ -39,6 +41,9 @@ public JlamaLanguageModel(Path modelCachePath, if (quantizeModelAtRuntime != null && quantizeModelAtRuntime) loader = loader.quantized(); + if (workingQuantizedType != null) + loader = loader.workingQuantizationType(workingQuantizedType); + if (threadCount != null) loader = loader.threadCount(threadCount); diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java index fc298900df6..1b50b1fe692 100644 
--- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModel.java @@ -60,6 +60,7 @@ public void download(Optional authToken) throws IOException { registry.getModelCachePath().toString(), owner, modelName, + true, Optional.empty(), authToken, Optional.empty()); @@ -67,6 +68,7 @@ public void download(Optional authToken) throws IOException { public class Loader { private Path workingDirectory; + private DType workingQuantizationType = DType.I8; private DType quantizationType; private Integer threadCount; private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION; @@ -75,11 +77,19 @@ private Loader() { } public Loader quantized() { - //For now only allow Q4 quantization at load time + //For now only allow Q4 quantization at runtime this.quantizationType = DType.Q4; return this; } + /** + * Set the working quantization type. This is the type that the model will use for working inference memory. + */ + public Loader workingQuantizationType(DType workingQuantizationType) { + this.workingQuantizationType = workingQuantizationType; + return this; + } + public Loader workingDirectory(Path workingDirectory) { this.workingDirectory = workingDirectory; return this; @@ -101,10 +111,11 @@ public AbstractModel load() { new File(registry.getModelCachePath().toFile(), modelName), workingDirectory == null ? 
null : workingDirectory.toFile(), DType.F32, - DType.I8, + workingQuantizationType, Optional.ofNullable(quantizationType), Optional.ofNullable(threadCount), - Optional.empty()); + Optional.empty(), + SafeTensorSupport::loadWeights); } } diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java index 36251dd5697..446b02dbb50 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaModelRegistry.java @@ -90,7 +90,7 @@ public JlamaModel downloadModel(String modelName, Optional authToken) th name = parts[1]; } - File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty()); + File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty()); File config = new File(modelDir, "config.json"); ModelSupport.ModelType type = SafeTensorSupport.detectModel(config); diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java index 90a74369010..5fe99a6c1a8 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingChatModel.java @@ -2,6 +2,7 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.functions.Generator; +import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.prompt.PromptSupport; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; @@ -34,6 +35,7 @@ public 
JlamaStreamingChatModel(Path modelCachePath, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, + DType workingQuantizedType, Float temperature, Integer maxTokens) { JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath); @@ -43,6 +45,9 @@ public JlamaStreamingChatModel(Path modelCachePath, if (quantizeModelAtRuntime != null && quantizeModelAtRuntime) loader = loader.quantized(); + if (workingQuantizedType != null) + loader = loader.workingQuantizationType(workingQuantizedType); + if (threadCount != null) loader = loader.threadCount(threadCount); diff --git a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java index 8a3e65280e7..0405f31c6f5 100644 --- a/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java +++ b/langchain4j-jlama/src/main/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModel.java @@ -2,6 +2,7 @@ import com.github.tjake.jlama.model.AbstractModel; import com.github.tjake.jlama.model.functions.Generator; +import com.github.tjake.jlama.safetensors.DType; import com.github.tjake.jlama.safetensors.prompt.PromptContext; import dev.langchain4j.internal.RetryUtils; import dev.langchain4j.model.StreamingResponseHandler; @@ -31,6 +32,7 @@ public JlamaStreamingLanguageModel(Path modelCachePath, Integer threadCount, Boolean quantizeModelAtRuntime, Path workingDirectory, + DType workingQuantizedType, Float temperature, Integer maxTokens) { JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath); @@ -40,6 +42,9 @@ public JlamaStreamingLanguageModel(Path modelCachePath, if (quantizeModelAtRuntime != null && quantizeModelAtRuntime) loader = loader.quantized(); + if (workingQuantizedType != null) + loader = loader.workingQuantizationType(workingQuantizedType); + if (threadCount != null) loader = 
loader.threadCount(threadCount); diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java index 8e715c1c9cb..31085aa422d 100644 --- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java +++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaChatModelIT.java @@ -29,7 +29,7 @@ static void setup() { .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4") .modelCachePath(tmpDir.toPath()) .temperature(0.0f) - .maxTokens(30) + .maxTokens(64) .build(); } diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java index d4d614a066a..c304e23e9fa 100644 --- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java +++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaLanguageModelIT.java @@ -26,7 +26,7 @@ static void setup() { .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4") .modelCachePath(tmpDir.toPath()) .temperature(0.0f) - .maxTokens(30) + .maxTokens(64) .build(); } diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java index e4bdc4baf5a..97e5fdd58e8 100644 --- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java +++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingChatModelIT.java @@ -27,7 +27,7 @@ static void setup() { model = JlamaStreamingChatModel.builder() .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4") .modelCachePath(tmpDir.toPath()) - .maxTokens(30) + .maxTokens(64) .temperature(0.0f) .build(); } diff --git a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java 
b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java index 1d9c30217c3..818a6a2d532 100644 --- a/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java +++ b/langchain4j-jlama/src/test/java/dev/langchain4j/model/jlama/JlamaStreamingLanguageModelIT.java @@ -26,7 +26,7 @@ static void setup() { model = JlamaStreamingLanguageModel.builder() .modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4") .modelCachePath(tmpDir.toPath()) - .maxTokens(30) + .maxTokens(64) .temperature(0.0f) .build(); }