Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance syntax for nested mapping in destination fields #841

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- InferenceProcessor inherits from AbstractBatchingProcessor to support sub batching in processor [#820](https://github.com/opensearch-project/neural-search/pull/820)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
- Enhance syntax for nested mapping in destination fields([#841](https://github.com/opensearch-project/neural-search/pull/841))
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;

import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
Expand All @@ -50,6 +52,17 @@ public abstract class InferenceProcessor extends AbstractBatchingProcessor {

public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
};

private final String type;

Expand Down Expand Up @@ -325,17 +338,7 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
});
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
Expand Down Expand Up @@ -389,8 +392,9 @@ Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> res
IndexWrapper indexWrapper = new IndexWrapper(0);
Map<String, Object> result = new LinkedHashMap<>();
for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
String knnKey = knnMapEntry.getKey();
Object sourceValue = knnMapEntry.getValue();
Pair<String, Object> processedNestedKey = processNestedKey(knnMapEntry);
String knnKey = processedNestedKey.getKey();
Object sourceValue = processedNestedKey.getValue();
if (sourceValue instanceof String) {
result.put(knnKey, results.get(indexWrapper.index++));
} else if (sourceValue instanceof List) {
Expand Down Expand Up @@ -419,19 +423,31 @@ private void putNLPResultToSourceMapForMapType(
nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++));
}
} else {
Pair<String, Object> processedNestedKey = processNestedKey(inputNestedMapEntry);
Map<String, Object> sourceMap;
if (sourceAndMetadataMap.get(processorKey) == null) {
sourceMap = new HashMap<>();
sourceAndMetadataMap.put(processorKey, sourceMap);
} else {
sourceMap = (Map<String, Object>) sourceAndMetadataMap.get(processorKey);
}
putNLPResultToSourceMapForMapType(
inputNestedMapEntry.getKey(),
inputNestedMapEntry.getValue(),
processedNestedKey.getKey(),
processedNestedKey.getValue(),
results,
indexWrapper,
(Map<String, Object>) sourceAndMetadataMap.get(processorKey)
sourceMap
);
}
}
} else if (sourceValue instanceof String) {
sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
} else if (sourceValue instanceof List) {
sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
sourceAndMetadataMap.merge(
processorKey,
buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper),
REMAPPING_FUNCTION
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -41,10 +42,15 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
protected static final String LEVEL_1_FIELD = "nested_passages";
protected static final String LEVEL_2_FIELD = "level_2";
protected static final String LEVEL_3_FIELD_TEXT = "level_3_text";
protected static final String LEVEL_3_FIELD_CONTAINER = "level_3_container";
protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding";
protected static final String TEXT_FIELD_VALUE_1 = "hello";
protected static final String TEXT_FIELD_VALUE_2 = "clown";
protected static final String TEXT_FIELD_VALUE_3 = "abc";
private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI()));
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
Expand Down Expand Up @@ -99,23 +105,17 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC3, "3");
ingestDocument(INGEST_DOC4, "4");

Map<String, Object> sourceMap = (Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source");
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(QUERY_TEXT, level2.get(LEVEL_3_FIELD_TEXT));
assertTrue(level2.containsKey(LEVEL_3_FIELD_EMBEDDING));
List<Double> embeddings = (List<Double>) level2.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
assertDoc(
(Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source"),
TEXT_FIELD_VALUE_1,
Optional.of(TEXT_FIELD_VALUE_3)
);
assertDoc((Map<String, Object>) getDocById(INDEX_NAME, "4").get("_source"), TEXT_FIELD_VALUE_2, Optional.empty());

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_EMBEDDING,
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
Expand All @@ -133,7 +133,7 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 1);
Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
Expand All @@ -142,15 +142,38 @@ public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws
assertEquals(1.0, hits.get("max_score"));
List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());
Map<String, Object> hitsInner = listOfHits.get(0);
assertEquals("3", hitsInner.get("_id"));
assertEquals(1.0, hitsInner.get("_score"));
assertEquals(2, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.get(0);
assertEquals("3", innerHitDetails.get("_id"));
assertEquals(1.0, innerHitDetails.get("_score"));

innerHitDetails = listOfHits.get(1);
assertEquals("4", innerHitDetails.get("_id"));
assertTrue((double) innerHitDetails.get("_score") <= 1.0);
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Optional<String> level3ExpectedValue) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(textFieldValue, level2.get(LEVEL_3_FIELD_TEXT));
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
if (level3ExpectedValue.isPresent()) {
assertEquals(level3ExpectedValue.get(), level3.get("level_4_text_field"));
}
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase {
protected static final String CHILD_FIELD_LEVEL_2 = "child_level2";
protected static final String CHILD_LEVEL_2_TEXT_FIELD_VALUE = "text_field_value";
protected static final String CHILD_LEVEL_2_KNN_FIELD = "test3_knn";
protected static final String CHILD_1_TEXT_FIELD = "child_1_text_field";
protected static final String TEXT_VALUE_1 = "text_value";
protected static final String TEXT_FIELD_2 = "abc";
@Mock
private MLCommonsClientAccessor mlCommonsClientAccessor;

Expand Down Expand Up @@ -363,6 +366,126 @@ public void testNestedFieldInMapping_withMapTypeInput_successful() {
}
}

@SneakyThrows
public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentHasTheDestinationStructure_theSuccessful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
/*
modeling following document:
parent:
child_level_1:
child_level_1_text_field: "text"
child_level_2:
child_level_2_text_field: "abc"
*/
Map<String, String> childLevel2NestedField = new HashMap<>();
childLevel2NestedField.put(CHILD_LEVEL_2_TEXT_FIELD_VALUE, TEXT_FIELD_2);
Map<String, Object> childLevel2 = new HashMap<>();
childLevel2.put(CHILD_FIELD_LEVEL_2, childLevel2NestedField);
childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);
Map<String, Object> childLevel1 = new HashMap<>();
childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2);
sourceAndMetadata.put(PARENT_FIELD, childLevel1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());

Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(
TextEmbeddingProcessor.FIELD_MAP_FIELD,
ImmutableMap.of(
String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)),
CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD
)
);
TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(
registry,
PROCESSOR_TAG,
DESCRIPTION,
config
);

List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {});
assertNotNull(ingestDocument);
assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD));
Map<String, Object> parent1AfterProcessor = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD);
Map<String, Object> childLevel1Actual = (Map<String, Object>) parent1AfterProcessor.get(CHILD_FIELD_LEVEL_1);
assertEquals(2, childLevel1Actual.size());
assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD));
Map<String, Object> child2Actual = (Map<String, Object>) childLevel1Actual.get(CHILD_FIELD_LEVEL_2);
assertEquals(2, child2Actual.size());
assertEquals(TEXT_FIELD_2, child2Actual.get(CHILD_LEVEL_2_TEXT_FIELD_VALUE));
List<Float> vectors = (List<Float>) child2Actual.get(CHILD_LEVEL_2_KNN_FIELD);
assertEquals(100, vectors.size());
for (Float vector : vectors) {
assertTrue(vector >= 0.0f && vector <= 1.0f);
}
}

@SneakyThrows
public void testNestedFieldInMappingForSourceAndDestination_withIngestDocumentWithoutDestinationStructure_theSuccessful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
/*
modeling following document:
parent:
child_level_1:
child_level_1_text_field: "text"
*/
Map<String, Object> childLevel2 = new HashMap<>();
childLevel2.put(CHILD_1_TEXT_FIELD, TEXT_VALUE_1);
Map<String, Object> childLevel1 = new HashMap<>();
childLevel1.put(CHILD_FIELD_LEVEL_1, childLevel2);
sourceAndMetadata.put(PARENT_FIELD, childLevel1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());

Map<String, Processor.Factory> registry = new HashMap<>();
Map<String, Object> config = new HashMap<>();
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(
TextEmbeddingProcessor.FIELD_MAP_FIELD,
ImmutableMap.of(
String.join(".", Arrays.asList(PARENT_FIELD, CHILD_FIELD_LEVEL_1, CHILD_1_TEXT_FIELD)),
CHILD_FIELD_LEVEL_2 + "." + CHILD_LEVEL_2_KNN_FIELD
)
);
TextEmbeddingProcessor processor = (TextEmbeddingProcessor) textEmbeddingProcessorFactory.create(
registry,
PROCESSOR_TAG,
DESCRIPTION,
config
);

List<List<Float>> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f);
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

processor.execute(ingestDocument, (BiConsumer) (doc, ex) -> {});
assertNotNull(ingestDocument);
assertNotNull(ingestDocument.getSourceAndMetadata().get(PARENT_FIELD));
Map<String, Object> parent1AfterProcessor = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get(PARENT_FIELD);
Map<String, Object> childLevel1Actual = (Map<String, Object>) parent1AfterProcessor.get(CHILD_FIELD_LEVEL_1);
assertEquals(2, childLevel1Actual.size());
assertEquals(TEXT_VALUE_1, childLevel1Actual.get(CHILD_1_TEXT_FIELD));
Map<String, Object> child2Actual = (Map<String, Object>) childLevel1Actual.get(CHILD_FIELD_LEVEL_2);
assertEquals(1, child2Actual.size());
List<Float> vectors = (List<Float>) child2Actual.get(CHILD_LEVEL_2_KNN_FIELD);
assertEquals(100, vectors.size());
for (Float vector : vectors) {
assertTrue(vector >= 0.0f && vector <= 1.0f);
}
}

@SneakyThrows
public void testNestedFieldInMappingMixedSyntax_withMapTypeInput_successful() {
Map<String, Object> sourceAndMetadata = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQueryBuilder;

Expand Down Expand Up @@ -119,6 +120,7 @@ public void testRewrite_whenRewriteQuery_thenSuccessful() {
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(VECTOR_FIELD_NAME, VECTOR_QUERY, K);
Query knnQuery = knnQueryBuilder.toQuery(mockQueryShardContext);

Expand Down
Loading
Loading