Skip to content

Commit

Permalink
Jlama revision bump, add working Q type to builder (langchain4j#1825)
Browse files Browse the repository at this point in the history
## Change
Bump jlama rev to 0.5.0

This rev is currently being promoted in maven so should be there in next
6 hours.

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
  • Loading branch information
tjake authored Sep 25, 2024
1 parent 3579664 commit 21eb8b9
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion langchain4j-jlama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[Jlama](https://github.com/tjake/Jlama) is a Java library that provides a simple way to integrate LLM models into Java
applications.

Jlama is built with Java 21 and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.
Jlama is built with Java 20+ and utilizes the new [Vector API](https://openjdk.org/jeps/448) for faster inference.

Jlama uses huggingface models in safetensor format.
Models must be specified using the `owner/model-name` format. For example, `meta-llama/Llama-2-7b-chat-hf`.
Expand Down
4 changes: 2 additions & 2 deletions langchain4j-jlama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion>
<artifactId>langchain4j-jlama</artifactId>
<name>LangChain4j :: Integration :: Jlama</name>
<description>Jlama: Pure Java LLM Inference Engine - Requires Java 21</description>
<description>Jlama: LLM Inference Engine for Java - Requires Java 20+</description>

<parent>
<groupId>dev.langchain4j</groupId>
Expand All @@ -15,7 +15,7 @@
</parent>

<properties>
<jlama.version>0.3.1</jlama.version>
<jlama.version>0.5.0</jlama.version>
<jackson.version>2.16.1</jackson.version>
<spotless.version>2.40.0</spotless.version>
<maven.compiler.release>21</maven.compiler.release>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.*;
import com.github.tjake.jlama.util.JsonSupport;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
Expand Down Expand Up @@ -33,6 +34,7 @@ public JlamaChatModel(Path modelCachePath,
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
Expand All @@ -42,6 +44,9 @@ public JlamaChatModel(Path modelCachePath,
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();

if (workingQuantizedType != null)
loader = loader.workingQuantizationType(workingQuantizedType);

if (threadCount != null)
loader = loader.threadCount(threadCount);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.bert.BertModel;
import com.github.tjake.jlama.model.functions.Generator;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
Expand All @@ -20,13 +21,15 @@

public class JlamaEmbeddingModel extends DimensionAwareEmbeddingModel {
private final BertModel model;
private final Generator.PoolingType poolingType;

@Builder
public JlamaEmbeddingModel(Path modelCachePath,
String modelName,
String authToken,
Integer threadCount,
Boolean quantizeModelAtRuntime,
Generator.PoolingType poolingType,
Path workingDirectory) {

JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
Expand All @@ -46,10 +49,12 @@ public JlamaEmbeddingModel(Path modelCachePath,
if (workingDirectory != null)
loader = loader.workingDirectory(workingDirectory);

loader = loader.inferenceType(AbstractModel.InferenceType.FORWARD_PASS);
loader = loader.inferenceType(AbstractModel.InferenceType.FULL_EMBEDDING);

this.model = (BertModel) loader.load();
this.dimension = model.getConfig().embeddingLength;

this.poolingType = poolingType == null ? Generator.PoolingType.MODEL : poolingType;
}

public static JlamaEmbeddingModelBuilder builder() {
Expand All @@ -64,7 +69,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<Embedding> embeddings = new ArrayList<>();

textSegments.forEach(textSegment -> {
embeddings.add(Embedding.from(model.embed(textSegment.text())));
embeddings.add(Embedding.from(model.embed(textSegment.text(), poolingType)));
});

return Response.from(embeddings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.jlama.spi.JlamaLanguageModelBuilderFactory;
Expand Down Expand Up @@ -30,6 +31,7 @@ public JlamaLanguageModel(Path modelCachePath,
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
Expand All @@ -39,6 +41,9 @@ public JlamaLanguageModel(Path modelCachePath,
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();

if (workingQuantizedType != null)
loader = loader.workingQuantizationType(workingQuantizedType);

if (threadCount != null)
loader = loader.threadCount(threadCount);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ public void download(Optional<String> authToken) throws IOException {
registry.getModelCachePath().toString(),
owner,
modelName,
true,
Optional.empty(),
authToken,
Optional.empty());
}

public class Loader {
private Path workingDirectory;
private DType workingQuantizationType = DType.I8;
private DType quantizationType;
private Integer threadCount;
private AbstractModel.InferenceType inferenceType = AbstractModel.InferenceType.FULL_GENERATION;
Expand All @@ -75,11 +77,19 @@ private Loader() {
}

public Loader quantized() {
//For now only allow Q4 quantization at load time
//For now only allow Q4 quantization at runtime
this.quantizationType = DType.Q4;
return this;
}

/**
* Set the working quantization type. This is the type that the model will use for working inference memory.
*/
public Loader workingQuantizationType(DType workingQuantizationType) {
this.workingQuantizationType = workingQuantizationType;
return this;
}

public Loader workingDirectory(Path workingDirectory) {
this.workingDirectory = workingDirectory;
return this;
Expand All @@ -101,10 +111,11 @@ public AbstractModel load() {
new File(registry.getModelCachePath().toFile(), modelName),
workingDirectory == null ? null : workingDirectory.toFile(),
DType.F32,
DType.I8,
workingQuantizationType,
Optional.ofNullable(quantizationType),
Optional.ofNullable(threadCount),
Optional.empty());
Optional.empty(),
SafeTensorSupport::loadWeights);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public JlamaModel downloadModel(String modelName, Optional<String> authToken) th
name = parts[1];
}

File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, Optional.empty(), authToken, Optional.empty());
File modelDir = SafeTensorSupport.maybeDownloadModel(modelCachePath.toString(), Optional.ofNullable(owner), name, true, Optional.empty(), authToken, Optional.empty());

File config = new File(modelDir, "config.json");
ModelSupport.ModelType type = SafeTensorSupport.detectModel(config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -34,6 +35,7 @@ public JlamaStreamingChatModel(Path modelCachePath,
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
Expand All @@ -43,6 +45,9 @@ public JlamaStreamingChatModel(Path modelCachePath,
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();

if (workingQuantizedType != null)
loader = loader.workingQuantizationType(workingQuantizedType);

if (threadCount != null)
loader = loader.threadCount(threadCount);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.StreamingResponseHandler;
Expand Down Expand Up @@ -31,6 +32,7 @@ public JlamaStreamingLanguageModel(Path modelCachePath,
Integer threadCount,
Boolean quantizeModelAtRuntime,
Path workingDirectory,
DType workingQuantizedType,
Float temperature,
Integer maxTokens) {
JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(modelCachePath);
Expand All @@ -40,6 +42,9 @@ public JlamaStreamingLanguageModel(Path modelCachePath,
if (quantizeModelAtRuntime != null && quantizeModelAtRuntime)
loader = loader.quantized();

if (workingQuantizedType != null)
loader = loader.workingQuantizationType(workingQuantizedType);

if (threadCount != null)
loader = loader.threadCount(threadCount);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ static void setup() {
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.temperature(0.0f)
.maxTokens(30)
.maxTokens(64)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static void setup() {
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.temperature(0.0f)
.maxTokens(30)
.maxTokens(64)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static void setup() {
model = JlamaStreamingChatModel.builder()
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.maxTokens(30)
.maxTokens(64)
.temperature(0.0f)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static void setup() {
model = JlamaStreamingLanguageModel.builder()
.modelName("tjake/Meta-Llama-3.1-8B-Instruct-Jlama-Q4")
.modelCachePath(tmpDir.toPath())
.maxTokens(30)
.maxTokens(64)
.temperature(0.0f)
.build();
}
Expand Down

0 comments on commit 21eb8b9

Please sign in to comment.