Skip to content

Commit

Permalink
Refactored in-process embeddings, added more docs (langchain4j#55)
Browse files Browse the repository at this point in the history
Co-authored-by: deep-learning-dynamo <deep-learning-dynamo@gmail.com>
  • Loading branch information
langchain4j and deep-learning-dynamo authored Aug 1, 2023
1 parent 200bfe1 commit d1f5257
Show file tree
Hide file tree
Showing 21 changed files with 305 additions and 91,625 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ public static boolean isNullOrBlank(String string) {
return string == null || string.trim().isEmpty();
}

/**
 * Builds a string made of {@code string} appended {@code times} times in a row.
 *
 * @param string the text to append on every iteration
 * @param times  how many copies to append; zero or a negative value yields an empty string
 * @return the concatenation of all appended copies
 */
public static String repeat(String string, int times) {
    StringBuilder repeated = new StringBuilder();
    int remaining = times;
    // Count down instead of up; the number of appends is the same.
    while (remaining > 0) {
        repeated.append(string);
        remaining--;
    }
    return repeated.toString();
}

/**
 * Wraps the given elements in a list via {@link Arrays#asList(Object[])}.
 * <p>
 * The result is the fixed-size, array-backed list that {@code Arrays.asList} produces,
 * so elements can be replaced but not added or removed.
 *
 * @param elements the elements to expose as a list
 * @param <T>      the element type
 * @return a list view over {@code elements}
 */
// The varargs array is only handed to Arrays.asList and never written to with a foreign type,
// so the annotation is justified and silences unchecked warnings at every call site.
@SafeVarargs
public static <T> List<T> list(T... elements) {
    return Arrays.asList(elements);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@ default int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}

int estimateTokenCount(List<TextSegment> textSegments);
/**
 * Estimates the total token count of several text segments by summing the
 * per-segment estimates.
 *
 * @param textSegments the segments to estimate
 * @return the sum of {@code estimateTokenCount(TextSegment)} over all segments
 */
default int estimateTokenCount(List<TextSegment> textSegments) {
    return textSegments.stream()
            .mapToInt(segment -> estimateTokenCount(segment))
            .sum();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

/**
 * Embedding model that extends AbstractInProcessEmbeddingModel and supplies it with one
 * statically created ONNX model object, constructed with the path "/all-minilm-l6-v2-q.onnx".
 *
 * NOTE(review): this listing is a rendered diff without +/- markers, so two declarations of
 * the field and of model() appear below (OnnxEmbeddingModel and OnnxBertEmbeddingModel).
 * Presumably only the OnnxBertEmbeddingModel lines remain after the commit - confirm against
 * the repository.
 */
public class ALL_MINILM_L6_V2_Q_EmbeddingModel extends AbstractInProcessEmbeddingModel {

// Declaration taking both a model path and "/vocab.txt".
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/all-minilm-l6-v2-q.onnx",
"/vocab.txt"
);
// Declaration taking only the model path; no vocabulary path is passed here.
private static final OnnxBertEmbeddingModel model = new OnnxBertEmbeddingModel("/all-minilm-l6-v2-q.onnx");

// Returns the single static model instance declared above.
@Override
protected OnnxEmbeddingModel model() {
protected OnnxBertEmbeddingModel model() {
return model;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static dev.langchain4j.internal.Utils.repeat;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class ALL_MINILM_L6_V2_Q_EmbeddingModelTest {

Expand All @@ -22,4 +24,30 @@ void should_embed() {

assertThat(Similarity.cosine(first.vector(), second.vector())).isGreaterThan(0.8);
}

/**
 * Embeds a text made of 510 repetitions of "hello " and checks that the
 * resulting vector has 384 elements.
 */
@Test
@Disabled("Temporary disabling. This test should run only when this or used (e.g. langchain4j-embeddings) module(s) change")
void should_embed_510_token_long_text() {

    EmbeddingModel embeddingModel = new ALL_MINILM_L6_V2_Q_EmbeddingModel();
    String longestAcceptedText = repeat("hello ", 510);

    Embedding result = embeddingModel.embed(longestAcceptedText);

    assertThat(result.vector()).hasSize(384);
}

/**
 * Checks that embedding a text made of 511 repetitions of "hello " is rejected
 * with an IllegalArgumentException whose message reports the 510-token limit.
 */
@Test
@Disabled("Temporary disabling. This test should run only when this or used (e.g. langchain4j-embeddings) module(s) change")
void should_fail_to_embed_511_token_long_text() {

    EmbeddingModel embeddingModel = new ALL_MINILM_L6_V2_Q_EmbeddingModel();
    String tooLongText = repeat("hello ", 511);

    assertThatThrownBy(() -> embeddingModel.embed(tooLongText))
            .isExactlyInstanceOf(IllegalArgumentException.class)
            .hasMessageStartingWith("Cannot embed text longer than 510 tokens. The following text is 511 tokens long: hello hello");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

/**
 * Embedding model that extends AbstractInProcessEmbeddingModel and supplies it with one
 * statically created ONNX model object, constructed with the path "/all-minilm-l6-v2.onnx".
 *
 * NOTE(review): this listing is a rendered diff without +/- markers, so two declarations of
 * the field and of model() appear below (OnnxEmbeddingModel and OnnxBertEmbeddingModel).
 * Presumably only the OnnxBertEmbeddingModel lines remain after the commit - confirm against
 * the repository.
 */
public class ALL_MINILM_L6_V2_EmbeddingModel extends AbstractInProcessEmbeddingModel {

// Declaration taking both a model path and "/vocab.txt".
private static final OnnxEmbeddingModel model = new OnnxEmbeddingModel(
"/all-minilm-l6-v2.onnx",
"/vocab.txt"
);
// Declaration taking only the model path; no vocabulary path is passed here.
private static final OnnxBertEmbeddingModel model = new OnnxBertEmbeddingModel("/all-minilm-l6-v2.onnx");

// Returns the single static model instance declared above.
@Override
protected OnnxEmbeddingModel model() {
protected OnnxBertEmbeddingModel model() {
return model;
}
}
Loading

0 comments on commit d1f5257

Please sign in to comment.