Skip to content

Commit

Permalink
Support Embedding for Large Amounts of Texts (langchain4j#1142)
Browse files Browse the repository at this point in the history
fix: langchain4j#1140

<!-- Thank you so much for your contribution! -->
<!-- Please fill in all the sections below. -->

<!-- Please open the PR as a draft initially. Once it is reviewed and
approved, we will ask you to add documentation and examples. -->
<!-- Please note that PRs with breaking changes will be rejected. -->
<!-- Please note that PRs without tests will be rejected. -->

<!-- Please note that PRs will be reviewed based on the priority of the
issues they address. -->
<!-- We ask for your patience. We are doing our best to review your PR
as quickly as possible. -->
<!-- Please refrain from pinging and asking when it will be reviewed.
Thank you for understanding! -->


## Issue
<!-- Please paste the link to the issue this PR is addressing. For
example: langchain4j#1012 -->


## Change
<!-- Please describe the changes you made. -->


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)


## Checklist for adding new model integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for adding new embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have added a `{NameOfIntegration}EmbeddingStoreIT` that extends
from either `EmbeddingStoreIT` or `EmbeddingStoreWithFilteringIT`
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)


## Checklist for changing existing embedding store integration
<!-- Please double-check the following points and mark them like this:
[X] -->
- [ ] I have manually verified that the
`{NameOfIntegration}EmbeddingStore` works correctly with the data
persisted using the latest released version of LangChain4j
  • Loading branch information
jiangsier-xyz authored May 24, 2024
1 parent 463a3a3 commit 4398395
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class QwenEmbeddingModel implements EmbeddingModel {
public static final String TYPE_KEY = "type";
public static final String TYPE_QUERY = "query";
public static final String TYPE_DOCUMENT = "document";
private static final int MAX_BATCH_SIZE = 25;

private final String apiKey;
private final String modelName;
Expand Down Expand Up @@ -53,7 +54,30 @@ private boolean containsQueries(List<TextSegment> textSegments) {
.anyMatch(TYPE_QUERY::equalsIgnoreCase);
}

private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
private Response<List<Embedding>> embedTexts(List<TextSegment> textSegments,
TextEmbeddingParam.TextType textType) {
int size = textSegments.size();
if (size < MAX_BATCH_SIZE) {
return batchEmbedTexts(textSegments, textType);
}

List<Embedding> allEmbeddings = new ArrayList<>(size);
TokenUsage allUsage = null;
int fromIndex = 0;
int toIndex = MAX_BATCH_SIZE;
while (fromIndex < size) {
List<TextSegment> batchTextSegments = textSegments.subList(fromIndex, toIndex);
Response<List<Embedding>> batchResponse = batchEmbedTexts(batchTextSegments, textType);
allEmbeddings.addAll(batchResponse.content());
allUsage = TokenUsage.sum(allUsage, batchResponse.tokenUsage());
fromIndex = toIndex;
toIndex = Math.min(size, fromIndex + MAX_BATCH_SIZE);
}

return Response.from(allEmbeddings, allUsage);
}

private Response<List<Embedding>> batchEmbedTexts(List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
TextEmbeddingParam param = TextEmbeddingParam.builder()
.apiKey(apiKey)
.model(modelName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static dev.langchain4j.data.segment.TextSegment.textSegment;
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_KEY;
import static dev.langchain4j.model.dashscope.QwenEmbeddingModel.TYPE_QUERY;
import static dev.langchain4j.model.dashscope.QwenTestHelper.apiKey;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -52,8 +57,8 @@ void should_embed_documents(String modelName) {
void should_embed_queries(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
textSegment("how are you?", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY))
textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
textSegment("how are you?", Metadata.from(TYPE_KEY, TYPE_QUERY))
)).content();

assertThat(embeddings).hasSize(2);
Expand All @@ -66,12 +71,47 @@ void should_embed_queries(String modelName) {
void should_embed_mix_segments(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(asList(
textSegment("hello", Metadata.from(QwenEmbeddingModel.TYPE_KEY, QwenEmbeddingModel.TYPE_QUERY)),
textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)),
textSegment("how are you?")
)).content();

assertThat(embeddings).hasSize(2);
assertThat(embeddings.get(0).vector()).isNotEmpty();
assertThat(embeddings.get(1).vector()).isNotEmpty();
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_documents(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Collections.nCopies(50, textSegment("hello"))).content();

assertThat(embeddings).hasSize(50);
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_queries(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY)))
).content();

assertThat(embeddings).hasSize(50);
}

@ParameterizedTest
@MethodSource("dev.langchain4j.model.dashscope.QwenTestHelper#embeddingModelNameProvider")
void should_embed_large_amounts_of_mix_segments(String modelName) {
EmbeddingModel model = getModel(modelName);
List<Embedding> embeddings = model.embedAll(
Stream.concat(
Collections.nCopies(50, textSegment("hello", Metadata.from(TYPE_KEY, TYPE_QUERY))).stream(),
Collections.nCopies(50, textSegment("how are you?")).stream()
).collect(Collectors.toList())
).content();

assertThat(embeddings).hasSize(100);
}
}

0 comments on commit 4398395

Please sign in to comment.