GH-1199: Prevent timeouts with configurable batching for PgVectorStore #1400

Closed · wants to merge 1 commit
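For context: this change splits large document writes into bounded JDBC batches. Embedding still happens in a single pass (governed by the existing BatchingStrategy), but the INSERT ... ON CONFLICT upserts are issued in chunks of at most maxDocumentBatchSize rows, which is what prevents the timeouts referenced in GH-1199. A minimal usage sketch against the builder API shown in this diff — jdbcTemplate, embeddingModel, and documents are assumed to exist, and the 5_000 cap is illustrative, not a recommendation:

    // Sketch only: Builder constructor and withMaxDocumentBatchSize are taken from the diff below.
    PgVectorStore store = new PgVectorStore.Builder(jdbcTemplate, embeddingModel)
            .withMaxDocumentBatchSize(5_000) // default is MAX_DOCUMENT_BATCH_SIZE (10_000)
            .build();

    store.add(documents); // embeds once, then upserts in chunks of at most 5_000 rows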
PgVectorStoreAutoConfiguration.java

@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
             .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
             .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
             .withBatchingStrategy(batchingStrategy)
+            .withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
             .build();
     }

PgVectorStoreProperties.java

@@ -24,6 +24,7 @@
 /**
  * @author Christian Tzolov
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
 @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
 public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {
 
     private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
 
+    private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
+
     public int getDimensions() {
         return dimensions;
     }
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
         this.schemaValidation = schemaValidation;
     }
 
+    public int getMaxDocumentBatchSize() {
+        return this.maxDocumentBatchSize;
+    }
+
+    public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
+        this.maxDocumentBatchSize = maxDocumentBatchSize;
+    }
+
 }
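Because the new field follows standard JavaBean conventions, it binds through Spring Boot relaxed binding under the PgVectorStore configuration prefix. A hypothetical application.properties entry, assuming CONFIG_PREFIX resolves to spring.ai.vectorstore.pgvector (the constant's value is not shown in this diff):

    # Cap each upsert batch at 5000 documents (illustrative value).
    spring.ai.vectorstore.pgvector.max-document-batch-size=5000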
PgVectorStore.java

@@ -18,10 +18,12 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import org.postgresql.util.PGobject;
 import org.slf4j.Logger;
@@ -81,6 +83,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
     public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();
 
+    public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;
+
     private final String vectorTableName;
 
     private final String vectorIndexName;
@@ -109,6 +113,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
 
     private final BatchingStrategy batchingStrategy;
 
+    private final int maxDocumentBatchSize;
+
     public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
         this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
                 PgIndexType.NONE, false);
@@ -132,7 +138,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
 
         this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions,
                 distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);
-
     }

@@ -141,14 +146,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
             JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
             boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema) {
 
         this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
                 distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
-                ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
+                ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
     }
 
     private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
             JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
             boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
             ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
-            BatchingStrategy batchingStrategy) {
+            BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {
 
         super(observationRegistry, customObservationConvention);

@@ -172,6 +177,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
         this.initializeSchema = initializeSchema;
         this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
         this.batchingStrategy = batchingStrategy;
+        this.maxDocumentBatchSize = maxDocumentBatchSize;
     }

@@ -180,40 +186,50 @@ public PgDistanceType getDistanceType() {
 
     @Override
     public void doAdd(List<Document> documents) {
+        this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
 
-        int size = documents.size();
+        List<List<Document>> batchedDocuments = batchDocuments(documents);
+        batchedDocuments.forEach(this::insertOrUpdateBatch);
+    }
 
-        this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
+    private List<List<Document>> batchDocuments(List<Document> documents) {
+        List<List<Document>> batches = new ArrayList<>();
+        for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
+            batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
+        }
+        return batches;
+    }
 
-        this.jdbcTemplate.batchUpdate(
-                "INSERT INTO " + getFullyQualifiedTableName()
-                        + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
-                        + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
-                new BatchPreparedStatementSetter() {
-                    @Override
-                    public void setValues(PreparedStatement ps, int i) throws SQLException {
-
-                        var document = documents.get(i);
-                        var content = document.getContent();
-                        var json = toJson(document.getMetadata());
-                        var embedding = document.getEmbedding();
-                        var pGvector = new PGvector(embedding);
-
-                        StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
-                                UUID.fromString(document.getId()));
-                        StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
-                        StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
-                        StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-                        StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
-                        StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
-                        StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-                    }
-
-                    @Override
-                    public int getBatchSize() {
-                        return size;
-                    }
-                });
+    private void insertOrUpdateBatch(List<Document> batch) {
+        String sql = "INSERT INTO " + getFullyQualifiedTableName()
+                + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
+                + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
+
+        this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
+            @Override
+            public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+                var document = batch.get(i);
+                var content = document.getContent();
+                var json = toJson(document.getMetadata());
+                var embedding = document.getEmbedding();
+                var pGvector = new PGvector(embedding);
+
+                StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+                        UUID.fromString(document.getId()));
+                StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+                StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+                StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+                StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+                StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+                StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+            }
 
+            @Override
+            public int getBatchSize() {
+                return batch.size();
+            }
+        });
+    }
 
     private String toJson(Map<String, Object> map) {
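To make the partitioning arithmetic concrete, here is a self-contained sketch of the same subList logic with hypothetical numbers, mirroring batchDocuments above (not part of the PR):

    import java.util.ArrayList;
    import java.util.List;

    class PartitionSketch {

        // Same loop shape as batchDocuments: step by the cap, clamp the final slice.
        static <T> List<List<T>> partition(List<T> items, int cap) {
            List<List<T>> batches = new ArrayList<>();
            for (int i = 0; i < items.size(); i += cap) {
                batches.add(items.subList(i, Math.min(i + cap, items.size())));
            }
            return batches;
        }

        public static void main(String[] args) {
            List<Integer> items = new ArrayList<>();
            for (int i = 0; i < 9989; i++) {
                items.add(i);
            }
            // 9989 items with a cap of 1000 -> 10 batches: nine of 1000 and one of 989,
            // exactly the counts asserted in the test at the end of this diff.
            List<List<Integer>> batches = partition(items, 1000);
            System.out.println(batches.size());        // 10
            System.out.println(batches.get(9).size()); // 989
        }
    }

Note that subList returns views over the original list, so the partitioning itself copies nothing; each chunk is then written by a single JDBC batchUpdate call.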
Expand Down Expand Up @@ -509,6 +525,8 @@ public static class Builder {

private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;

@Nullable
private VectorStoreObservationConvention searchObservationConvention;

@@ -576,11 +594,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
             return this;
         }
 
+        public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
+            this.maxDocumentBatchSize = maxDocumentBatchSize;
+            return this;
+        }
+
         public PgVectorStore build() {
             return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
                     this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
                     this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
-                    this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
+                    this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
+                    this.maxDocumentBatchSize);
         }
 
     }
PgVectorStoreTests.java

@@ -15,15 +15,31 @@
  */
 package org.springframework.ai.vectorstore;
 
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.mockito.ArgumentCaptor;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.only;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.util.Collections;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.jdbc.core.BatchPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcTemplate;
 
 /**
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
-
 public class PgVectorStoreTests {
 
     @ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
                                                                 // 64
                                                                 // characters
     })
-    public void isValidTable(String tableName, Boolean expected) {
+    void isValidTable(String tableName, Boolean expected) {
         assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
     }

+    @Test
+    void shouldAddDocumentsInBatchesAndEmbedOnce() {
+        // Given
+        var jdbcTemplate = mock(JdbcTemplate.class);
+        var embeddingModel = mock(EmbeddingModel.class);
+        var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
+            .build();
+
+        // Testing with 9989 documents
+        var documents = Collections.nCopies(9989, new Document("foo"));
+
+        // When
+        pgVectorStore.doAdd(documents);
+
+        // Then
+        verify(embeddingModel, only()).embed(eq(documents), any(), any());
+
+        var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
+        verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());
+
+        assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
+            .allSatisfy(BatchPreparedStatementSetter::getBatchSize)
+            .satisfies(batches -> {
+                for (int i = 0; i < 9; i++) {
+                    assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i)
+                        .isEqualTo(1000);
+                }
+                assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
+            });
+    }
+
 }