Vertex AI: embed in batches of 5
deep-learning-dynamo committed Oct 9, 2023
1 parent eef1796 commit 43917ee
Showing 2 changed files with 46 additions and 39 deletions.
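
The gist of the change: instead of sending every input text to Vertex AI in a single predict call, embedAll now walks the segment list in steps of five (Vertex AI accepts at most 5 input texts per request) and issues one request per sub-list. Below is a minimal, self-contained sketch of that subList-based partitioning idea; the class and method names (BatchingSketch, partition) are illustrative only, and the real code sends each batch to PredictionServiceClient.predict rather than collecting the batches.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    // Illustrative sketch of the batching introduced by this commit (not the actual model class):
    // the input list is walked in steps of 5 and each subList would become one embedding request.
    class BatchingSketch {

        private static final int BATCH_SIZE = 5; // Vertex AI has a limit of up to 5 input texts per request

        static List<List<String>> partition(List<String> texts) {
            List<List<String>> batches = new ArrayList<>();
            for (int i = 0; i < texts.size(); i += BATCH_SIZE) {
                batches.add(texts.subList(i, Math.min(i + BATCH_SIZE, texts.size())));
            }
            return batches;
        }

        public static void main(String[] args) {
            List<String> texts = Arrays.asList("one", "two", "three", "four", "five", "six");
            // Prints [[one, two, three, four, five], [six]] - 6 inputs become 2 requests
            System.out.println(partition(texts));
        }
    }
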
VertexAiEmbeddingModel.java

@@ -19,6 +19,7 @@
 import static com.google.cloud.aiplatform.util.ValueConverter.EMPTY_VALUE;
 import static dev.langchain4j.internal.Json.toJson;
 import static dev.langchain4j.internal.RetryUtils.withRetry;
+import static dev.langchain4j.internal.Utils.getOrDefault;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
 import static java.util.stream.Collectors.toList;
 
@@ -28,6 +29,8 @@
  */
 public class VertexAiEmbeddingModel implements EmbeddingModel {
 
+    private static final int BATCH_SIZE = 5; // Vertex AI has a limit of up to 5 input texts per request
+
     private final PredictionServiceSettings settings;
     private final EndpointName endpointName;
     private final Integer maxRetries;
@@ -51,38 +54,37 @@ public VertexAiEmbeddingModel(String endpoint,
                 ensureNotBlank(publisher, "publisher"),
                 ensureNotBlank(modelName, "modelName")
         );
-        this.maxRetries = maxRetries == null ? 3 : maxRetries;
+        this.maxRetries = getOrDefault(maxRetries, 3);
     }
 
     @Override
-    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
-        List<String> texts = textSegments.stream()
-                .map(TextSegment::text)
-                .collect(toList());
-
-        return embedTexts(texts);
-    }
-
-    private Response<List<Embedding>> embedTexts(List<String> texts) {
+    public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
+
         try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
 
-            List<Value> instances = new ArrayList<>();
-            for (String text : texts) {
-                Value.Builder instanceBuilder = Value.newBuilder();
-                JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(text)), instanceBuilder);
-                instances.add(instanceBuilder.build());
-            }
-
-            PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
-
-            List<Embedding> embeddings = response.getPredictionsList().stream()
-                    .map(VertexAiEmbeddingModel::toVector)
-                    .map(Embedding::from)
-                    .collect(toList());
-
-            int inputTokenCount = 0;
-            for (Value value : response.getPredictionsList()) {
-                inputTokenCount += extractTokenCount(value);
+            List<Embedding> embeddings = new ArrayList<>();
+            int inputTokenCount = 0;
+
+            for (int i = 0; i < segments.size(); i += BATCH_SIZE) {
+
+                List<TextSegment> batch = segments.subList(i, Math.min(i + BATCH_SIZE, segments.size()));
+
+                List<Value> instances = new ArrayList<>();
+                for (TextSegment segment : batch) {
+                    Value.Builder instanceBuilder = Value.newBuilder();
+                    JsonFormat.parser().merge(toJson(new VertexAiEmbeddingInstance(segment.text())), instanceBuilder);
+                    instances.add(instanceBuilder.build());
+                }
+
+                PredictResponse response = withRetry(() -> client.predict(endpointName, instances, EMPTY_VALUE), maxRetries);
+
+                embeddings.addAll(response.getPredictionsList().stream()
+                        .map(VertexAiEmbeddingModel::toEmbedding)
+                        .collect(toList()));
+
+                for (Value prediction : response.getPredictionsList()) {
+                    inputTokenCount += extractTokenCount(prediction);
+                }
             }
 
             return Response.from(
@@ -94,8 +96,9 @@ private Response<List<Embedding>> embedTexts(List<String> texts) {
         }
     }
 
-    private static List<Float> toVector(Value prediction) {
-        return prediction.getStructValue()
+    private static Embedding toEmbedding(Value prediction) {
+
+        List<Float> vector = prediction.getStructValue()
                 .getFieldsMap()
                 .get("embeddings")
                 .getStructValue()
@@ -105,10 +108,12 @@ private static List<Float> toVector(Value prediction) {
                 .stream()
                 .map(v -> (float) v.getNumberValue())
                 .collect(toList());
+
+        return Embedding.from(vector);
     }
 
-    private static int extractTokenCount(Value value) {
-        return (int) value.getStructValue()
+    private static int extractTokenCount(Value prediction) {
+        return (int) prediction.getStructValue()
                 .getFieldsMap()
                 .get("embeddings")
                 .getStructValue()

VertexAiEmbeddingModelIT.java

@@ -2,6 +2,7 @@
 
 import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.output.TokenUsage;
 import org.junit.jupiter.api.Disabled;
@@ -18,7 +19,8 @@ class VertexAiEmbeddingModelIT {
     @Test
     @Disabled("To run this test, you must have provide your own endpoint, project and location")
     void testEmbeddingModel() {
-        VertexAiEmbeddingModel vertexAiEmbeddingModel = VertexAiEmbeddingModel.builder()
+
+        EmbeddingModel embeddingModel = VertexAiEmbeddingModel.builder()
                 .endpoint("us-central1-aiplatform.googleapis.com:443")
                 .project("langchain4j")
                 .location("us-central1")
@@ -28,22 +30,22 @@ void testEmbeddingModel() {
                 .build();
 
         List<TextSegment> segments = asList(
-                TextSegment.from("hello world"),
-                TextSegment.from("how are you?")
+                TextSegment.from("one"),
+                TextSegment.from("two"),
+                TextSegment.from("three"),
+                TextSegment.from("four"),
+                TextSegment.from("five"),
+                TextSegment.from("six")
         );
 
-        Response<List<Embedding>> response = vertexAiEmbeddingModel.embedAll(segments);
+        Response<List<Embedding>> response = embeddingModel.embedAll(segments);
 
         List<Embedding> embeddings = response.content();
-        assertThat(embeddings).hasSize(2);
+        assertThat(embeddings).hasSize(6);
 
-        Embedding embedding1 = embeddings.get(0);
-        assertThat(embedding1.vector()).hasSize(768);
-        System.out.println(Arrays.toString(embedding1.vector()));
-
-        Embedding embedding2 = embeddings.get(1);
-        assertThat(embedding2.vector()).hasSize(768);
-        System.out.println(Arrays.toString(embedding2.vector()));
+        Embedding embedding = embeddings.get(0);
+        assertThat(embedding.vector()).hasSize(768);
+        System.out.println(Arrays.toString(embedding.vector()));
 
         TokenUsage tokenUsage = response.tokenUsage();
         assertThat(tokenUsage.inputTokenCount()).isEqualTo(6);
