diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 736fecafc..482a8730c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -105,7 +105,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer { - appendVectorFieldsToDocument(ingestDocument, knnMap, vectors); + setVectorFieldsToDocument(ingestDocument, knnMap, vectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @@ -115,11 +115,11 @@ public void execute(IngestDocument ingestDocument, BiConsumer knnMap, List> vectors) { + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); - textEmbeddingResult.forEach(ingestDocument::appendFieldValue); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); } @SuppressWarnings({ "unchecked" }) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 0b5384b77..2d376ad7c 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -17,6 +17,7 @@ import java.util.Locale; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -501,6 +502,7 @@ private String registerModelGroup() throws IOException, URISyntaxException { String modelGroupRegisterRequestBody = Files.readString( Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI()) ); + modelGroupRegisterRequestBody = modelGroupRegisterRequestBody.replace("", UUID.randomUUID().toString()); Response modelGroupResponse = makeRequest( client(), "POST", diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index b7dfec083..2eb552a3b 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -32,6 +32,8 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.NodeNotConnectedException; +import com.google.common.collect.ImmutableMap; + public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @Mock @@ -168,7 +170,9 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { output, new long[] { 1, 2 }, MLResultDataType.FLOAT64, - ByteBuffer.wrap(new byte[12]) + ByteBuffer.wrap(new byte[12]), + "mockResult", + ImmutableMap.of("mockKey", "mockValue") ); mlModelTensorList.add(tensor); final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index d4a92f103..e9349aefe 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -7,7 +7,14 @@ import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.ArrayList; import java.util.HashMap; @@ -350,7 +357,7 @@ public void testProcessResponse_successful() throws Exception { Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.appendVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); assertEquals(12, ingestDocument.getSourceAndMetadata().size()); } @@ -398,6 +405,20 @@ public void testBuildVectorOutput_withNestedMap_successful() { assertNotNull(actionGamesKnn); } + public void test_updateDocument_appendVectorFieldsToDocument_successful() { + Map config = createPlainStringConfiguration(); + IngestDocument ingestDocument = createPlainIngestDocument(); + TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); + Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + List> modelTensorList = createMockVectorResult(); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + + List> modelTensorList1 = createMockVectorResult(); + processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList1); + assertEquals(12, ingestDocument.getSourceAndMetadata().size()); + assertEquals(2, ((List) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size()); + } + private List> createMockVectorResult() { List> modelTensorList = new ArrayList<>(); List number1 = ImmutableList.of(1.234f, 2.354f); diff --git a/src/test/resources/processor/CreateModelGroupRequestBody.json b/src/test/resources/processor/CreateModelGroupRequestBody.json index 2fddae02e..51511222c 100644 --- a/src/test/resources/processor/CreateModelGroupRequestBody.json +++ b/src/test/resources/processor/CreateModelGroupRequestBody.json @@ -1,4 +1,4 @@ { - "name": "test_model_group_public", + "name": "", "description": "This is a public model group" }