Skip to content

Commit

Permalink
Milvus: configurable field names (langchain4j#1852)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1842 

## Change
Add `FieldDefinition` to hold customized filed name.
Replace default field name with customized filed name.

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)

## Checklist for changing existing embedding store integration
- [X] I have manually verified that the
`{NameOfIntegration}EmbeddingStore` works correctly with the data
persisted using the latest released version of LangChain4j
  • Loading branch information
hrhrng authored Oct 10, 2024
1 parent 250c59f commit 2bd251a
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,29 @@ static boolean hasCollection(MilvusServiceClient milvusClient, String collection
return response.getData();
}

static void createCollection(MilvusServiceClient milvusClient, String collectionName, int dimension) {
static void createCollection(MilvusServiceClient milvusClient, String collectionName, FieldDefinition fieldDefinition, int dimension) {

CreateCollectionParam request = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withSchema(CollectionSchemaParam.newBuilder()
.addFieldType(FieldType.newBuilder()
.withName(ID_FIELD_NAME)
.withName(fieldDefinition.getIdFieldName())
.withDataType(VarChar)
.withMaxLength(36)
.withPrimaryKey(true)
.withAutoID(false)
.build())
.addFieldType(FieldType.newBuilder()
.withName(TEXT_FIELD_NAME)
.withName(fieldDefinition.getTextFieldName())
.withDataType(VarChar)
.withMaxLength(65535)
.build())
.addFieldType(FieldType.newBuilder()
.withName(METADATA_FIELD_NAME)
.withName(fieldDefinition.getMetadataFieldName())
.withDataType(JSON)
.build())
.addFieldType(FieldType.newBuilder()
.withName(VECTOR_FIELD_NAME)
.withName(fieldDefinition.getVectorFieldName())
.withDataType(FloatVector)
.withDimension(dimension)
.build())
Expand All @@ -82,12 +82,13 @@ static void dropCollection(MilvusServiceClient milvusClient, String collectionNa

static void createIndex(MilvusServiceClient milvusClient,
String collectionName,
String vectorFieldName,
IndexType indexType,
MetricType metricType) {

CreateIndexParam request = CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withFieldName(VECTOR_FIELD_NAME)
.withFieldName(vectorFieldName)
.withIndexType(indexType)
.withMetricType(metricType)
.build();
Expand Down Expand Up @@ -117,9 +118,10 @@ static SearchResultsWrapper search(MilvusServiceClient milvusClient, SearchParam

static QueryResultsWrapper queryForVectors(MilvusServiceClient milvusClient,
String collectionName,
FieldDefinition fieldDefinition,
List<String> rowIds,
ConsistencyLevelEnum consistencyLevel) {
QueryParam request = buildQueryRequest(collectionName, rowIds, consistencyLevel);
QueryParam request = buildQueryRequest(collectionName, fieldDefinition, rowIds, consistencyLevel);
R<QueryResults> response = milvusClient.query(request);
checkResponseNotFailed(response);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ static LoadCollectionParam buildLoadCollectionInMemoryRequest(String collectionN
}

static SearchParam buildSearchRequest(String collectionName,
FieldDefinition fieldDefinition,
List<Float> vector,
Filter filter,
int maxResults,
Expand All @@ -62,27 +63,28 @@ static SearchParam buildSearchRequest(String collectionName,
SearchParam.Builder builder = SearchParam.newBuilder()
.withCollectionName(collectionName)
.withVectors(singletonList(vector))
.withVectorFieldName(VECTOR_FIELD_NAME)
.withVectorFieldName(fieldDefinition.getVectorFieldName())
.withTopK(maxResults)
.withMetricType(metricType)
.withConsistencyLevel(consistencyLevel)
.withOutFields(asList(ID_FIELD_NAME, TEXT_FIELD_NAME, METADATA_FIELD_NAME));
.withOutFields(asList(fieldDefinition.getIdFieldName(), fieldDefinition.getTextFieldName(), fieldDefinition.getMetadataFieldName()));

if (filter != null) {
builder.withExpr(MilvusMetadataFilterMapper.map(filter));
builder.withExpr(MilvusMetadataFilterMapper.map(filter, fieldDefinition.getMetadataFieldName()));
}

return builder.build();
}

static QueryParam buildQueryRequest(String collectionName,
FieldDefinition fieldDefinition,
List<String> rowIds,
ConsistencyLevelEnum consistencyLevel) {
return QueryParam.newBuilder()
.withCollectionName(collectionName)
.withExpr(buildQueryExpression(rowIds))
.withExpr(buildQueryExpression(rowIds, fieldDefinition.getIdFieldName()))
.withConsistencyLevel(consistencyLevel)
.withOutFields(singletonList(VECTOR_FIELD_NAME))
.withOutFields(singletonList(fieldDefinition.getVectorFieldName()))
.build();
}

Expand All @@ -94,9 +96,9 @@ static DeleteParam buildDeleteRequest(String collectionName,
.build();
}

private static String buildQueryExpression(List<String> rowIds) {
private static String buildQueryExpression(List<String> rowIds, String idFieldName) {
return rowIds.stream()
.map(id -> format("%s == '%s'", ID_FIELD_NAME, id))
.map(id -> format("%s == '%s'", idFieldName, id))
.collect(joining(" || "));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package dev.langchain4j.store.embedding.milvus;


class FieldDefinition {

String idFieldName;

String textFieldName;

String metadataFieldName;

String vectorFieldName;

public FieldDefinition(String idFieldName, String textFieldName, String metadataFieldName, String vectorFieldName) {
this.idFieldName = idFieldName;
this.textFieldName = textFieldName;
this.metadataFieldName = metadataFieldName;
this.vectorFieldName = vectorFieldName;
}

public String getIdFieldName() {
return idFieldName;
}

public String getTextFieldName() {
return textFieldName;
}

public String getMetadataFieldName() {
return metadataFieldName;
}

public String getVectorFieldName() {
return vectorFieldName;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import static dev.langchain4j.store.embedding.milvus.CollectionOperationsExecutor.queryForVectors;
import static dev.langchain4j.store.embedding.milvus.Generator.generateEmptyJsons;
import static dev.langchain4j.store.embedding.milvus.Generator.generateEmptyScalars;
import static dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore.*;
import static java.util.stream.Collectors.toList;

class Mapper {
Expand Down Expand Up @@ -61,15 +60,16 @@ static List<String> textSegmentsToScalars(List<TextSegment> textSegments) {
static List<EmbeddingMatch<TextSegment>> toEmbeddingMatches(MilvusServiceClient milvusClient,
SearchResultsWrapper resultsWrapper,
String collectionName,
FieldDefinition fieldDefinition,
ConsistencyLevelEnum consistencyLevel,
boolean queryForVectorOnSearch) {
List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();

Map<String, Embedding> idToEmbedding = new HashMap<>();
if (queryForVectorOnSearch) {
try {
List<String> rowIds = (List<String>) resultsWrapper.getFieldWrapper(ID_FIELD_NAME).getFieldData();
idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, rowIds, consistencyLevel));
List<String> rowIds = (List<String>) resultsWrapper.getFieldWrapper(fieldDefinition.getIdFieldName()).getFieldData();
idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, fieldDefinition, rowIds, consistencyLevel));
} catch (ParamException e) {
// There is no way to check if the result is empty or not.
// If the result is empty, the exception will be thrown.
Expand All @@ -80,7 +80,7 @@ static List<EmbeddingMatch<TextSegment>> toEmbeddingMatches(MilvusServiceClient
double score = resultsWrapper.getIDScore(0).get(i).getScore();
String rowId = resultsWrapper.getIDScore(0).get(i).getStrID();
Embedding embedding = idToEmbedding.get(rowId);
TextSegment textSegment = toTextSegment(resultsWrapper.getRowRecords().get(i));
TextSegment textSegment = toTextSegment(resultsWrapper.getRowRecords().get(i), fieldDefinition);
EmbeddingMatch<TextSegment> embeddingMatch = new EmbeddingMatch<>(
RelevanceScore.fromCosineSimilarity(score),
rowId,
Expand All @@ -93,18 +93,18 @@ static List<EmbeddingMatch<TextSegment>> toEmbeddingMatches(MilvusServiceClient
return matches;
}

private static TextSegment toTextSegment(RowRecord rowRecord) {
private static TextSegment toTextSegment(RowRecord rowRecord, FieldDefinition fieldDefinition) {

String text = (String) rowRecord.get(TEXT_FIELD_NAME);
String text = (String) rowRecord.get(fieldDefinition.getTextFieldName());
if (isNullOrBlank(text)) {
return null;
}

if (!rowRecord.getFieldValues().containsKey(METADATA_FIELD_NAME)) {
if (!rowRecord.getFieldValues().containsKey(fieldDefinition.getMetadataFieldName())) {
return TextSegment.from(text);
}

JsonObject metadata = (JsonObject) rowRecord.get(METADATA_FIELD_NAME);
JsonObject metadata = (JsonObject) rowRecord.get(fieldDefinition.getMetadataFieldName());
return TextSegment.from(text, toMetadata(metadata));
}

Expand All @@ -121,19 +121,21 @@ private static Metadata toMetadata(JsonObject metadata) {

private static Map<String, Embedding> queryEmbeddings(MilvusServiceClient milvusClient,
String collectionName,
FieldDefinition fieldDefinition,
List<String> rowIds,
ConsistencyLevelEnum consistencyLevel) {
QueryResultsWrapper queryResultsWrapper = queryForVectors(
milvusClient,
collectionName,
fieldDefinition,
rowIds,
consistencyLevel
);

Map<String, Embedding> idToEmbedding = new HashMap<>();
for (RowRecord row : queryResultsWrapper.getRowRecords()) {
String id = row.get(ID_FIELD_NAME).toString();
List<Float> vector = (List<Float>) row.get(VECTOR_FIELD_NAME);
String id = row.get(fieldDefinition.getIdFieldName()).toString();
List<Float> vector = (List<Float>) row.get(fieldDefinition.getVectorFieldName());
idToEmbedding.put(id, Embedding.from(vector));
}

Expand Down
Loading

0 comments on commit 2bd251a

Please sign in to comment.