Skip to content

Commit

Permalink
Address comments and add one more UT to cover uncovered line
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Apr 12, 2024
1 parent 0fcad86 commit f18868f
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 24 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670))
### Bug Fixes
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
- Fix may type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
- Fix map type validation issue in multiple pipeline processors ([#661](https://github.com/opensearch-project/neural-search/pull/661))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
fieldMap,
1,
ProcessorDocumentUtils.getMaxDepth(sourceAndMetadataMap, clusterService, environment),
true
false
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,29 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.index.mapper.MapperService;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
* This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple
* pipeline processors.
*/
public class ProcessorDocumentUtils {

/**
* This method is used to get the max depth of the index or from system settings.
*
* @param sourceAndMetadataMap _source and metadata info in document.
* @param clusterService cluster service passed from OpenSearch core.
* @param environment environment passed from OpenSearch core.
* @return max depth of the index or from system settings.
*/
public static long getMaxDepth(Map<String, Object> sourceAndMetadataMap, ClusterService clusterService, Environment environment) {
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
Expand All @@ -29,12 +41,23 @@ public static long getMaxDepth(Map<String, Object> sourceAndMetadataMap, Cluster
return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings());
}

/**
* Validates a map type value recursively up to a specified depth. Supports Map type, List type and String type.
* If current sourceValue is Map or List type, recursively validates its values, otherwise validates its value.
*
* @param sourceKey the key of the source map being validated, the first level is always the "field_map" key.
* @param sourceValue the source map being validated, the first level is always the sourceAndMetadataMap.
* @param fieldMap the configuration map for validation, the first level is always the value of "field_map" in the processor configuration.
* @param depth the current depth of recursion
* @param maxDepth the maximum allowed depth for recursion
* @param allowEmpty flag to allow empty values in map type validation.
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
public static void validateMapTypeValue(
final String sourceKey,
final Map<String, Object> sourceValue,
final Object fieldMap,
final int depth,
final long depth,
final long maxDepth,
final boolean allowEmpty
) {
Expand Down Expand Up @@ -73,18 +96,19 @@ private static void validateListTypeValue(
String sourceKey,
List sourceValue,
Object fieldMap,
int depth,
long depth,
long maxDepth,
boolean allowEmpty
) {
validateDepth(sourceKey, depth, maxDepth);
if (sourceValue == null || sourceValue.isEmpty()) return;
Object firstNonNullElement = sourceValue.stream().filter(Objects::nonNull).findFirst().orElse(null);
if (firstNonNullElement == null) return;
if (CollectionUtils.isEmpty(sourceValue)) return;
for (Object element : sourceValue) {
if (firstNonNullElement instanceof List) { // nested list case.
validateListTypeValue(sourceKey, (List) element, fieldMap, depth + 1, maxDepth, allowEmpty);
} else if (firstNonNullElement instanceof Map) {
if (element == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
}
if (element instanceof List) { // nested list case.
throw new IllegalArgumentException("list type field [" + sourceKey + "] is nested list type, cannot process it");
} else if (element instanceof Map) {
validateMapTypeValue(
sourceKey,
(Map<String, Object>) element,
Expand All @@ -93,23 +117,17 @@ private static void validateListTypeValue(
maxDepth,
allowEmpty
);
} else if (!(firstNonNullElement instanceof String)) {
} else if (!(element instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else {
if (element == null) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
} else if (!(element instanceof String)) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
} else if (!allowEmpty && StringUtils.isBlank(element.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
} else if (!allowEmpty && StringUtils.isBlank(element.toString())) {
throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
}
}
}

private static void validateDepth(String sourceKey, int depth, long maxDepth) {
private static void validateDepth(String sourceKey, long depth, long maxDepth) {
if (depth > maxDepth) {
throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
throw new IllegalArgumentException("map type field [" + sourceKey + "] reaches max depth limit, cannot process it");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ public void testExecute_withFixedTokenLength_andMaxDepthLimitExceedFieldMap_then
IllegalArgumentException.class,
() -> processor.execute(ingestDocument)
);
assertEquals("map type field [body] reached max depth limit, cannot process it", illegalArgumentException.getMessage());
assertEquals("map type field [body] reaches max depth limit, cannot process it", illegalArgumentException.getMessage());
}

@SneakyThrows
Expand Down Expand Up @@ -657,7 +657,10 @@ public void testExecute_withFixedTokenLength_andSourceDataListWithHybridType_the
IllegalArgumentException.class,
() -> processor.execute(ingestDocument)
);
assertEquals("list type field [body] has non string value, cannot process it", illegalArgumentException.getMessage());
assertEquals(
"[body] configuration doesn't match actual value type, configuration type is: java.lang.String, actual value type is: com.google.common.collect.RegularImmutableMap",
illegalArgumentException.getMessage()
);
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.function.Supplier;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
Expand Down Expand Up @@ -488,6 +489,28 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() {
assertEquals(2, ((List<?>) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size());
}

public void test_doublyNestedList_withMapType_successful() {
Map<String, Object> config = createNestedListConfiguration();

Map<String, Object> toEmbeddings = new HashMap<>();
toEmbeddings.put("textField", "text to embedding");
List<Map<String, Object>> l1List = new ArrayList<>();
l1List.add(toEmbeddings);
List<List<Map<String, Object>>> l2List = new ArrayList<>();
l2List.add(l1List);
Map<String, Object> document = new HashMap<>();
document.put("nestedField", l2List);
document.put(IndexFieldMapper.NAME, "my_index");

IngestDocument ingestDocument = new IngestDocument(document, new HashMap<>());
TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config);
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
ArgumentCaptor<IllegalArgumentException> argumentCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
verify(handler).accept(isNull(), argumentCaptor.capture());
assertEquals("list type field [nestedField] is nested list type, cannot process it", argumentCaptor.getValue().getMessage());
}

private List<List<Float>> createMockVectorResult() {
List<List<Float>> modelTensorList = new ArrayList<>();
List<Float> number1 = ImmutableList.of(1.234f, 2.354f);
Expand Down

0 comments on commit f18868f

Please sign in to comment.