diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index fb718e9be..46acf7ac2 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,14 +8,12 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -110,29 +108,6 @@ public void inferenceSentences( }, listener::onFailure)); } - /** - * Abstraction to call predict function of api of MLClient with provided targetResponseFilters. It uses the - * custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. We are not making this function generic enough to take any function or TaskType as currently we need - * to run only TextEmbedding tasks only. Please note this method is a blocking method, use this only when the processing - * needs block waiting for response, otherwise please use {@link #inferenceSentences(String, List, ActionListener)} - * instead. - * @param modelId {@link String} - * @param inputText {@link List} of {@link String} on which inference needs to happen. - * @return {@link List} of {@link List} of {@link String} represents the text embedding vector result. - * @throws ExecutionException If the underlying task failed, this exception will be thrown in the future.get(). - * @throws InterruptedException If the thread is interrupted, this will be thrown. - */ - public List> inferenceSentences(@NonNull final String modelId, @NonNull final List inputText) - throws ExecutionException, InterruptedException { - final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText); - final ActionFuture outputActionFuture = mlClient.predict(modelId, mlInput); - final List> vector = buildVectorFromResponse(outputActionFuture.get()); - log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); - return vector; - } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 36271d83c..9f711866f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -10,13 +10,14 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.stream.IntStream; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; +import org.opensearch.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; @@ -80,17 +81,21 @@ private void validateEmbeddingConfiguration(Map fieldMap) { @Override public IngestDocument execute(IngestDocument ingestDocument) { - validateEmbeddingFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + return ingestDocument; + } + + public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { - List> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap)); - appendVectorFieldsToDocument(ingestDocument, knnMap, vectors); - } catch (ExecutionException | InterruptedException e) { - log.error("Text embedding processor failed with exception: ", e); - throw new RuntimeException("Text embedding processor failed with exception", e); + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(x -> { + appendVectorFieldsToDocument(ingestDocument, knnMap, x); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } catch (Exception e) { + handler.accept(null, e); } - log.debug("Text embedding completed, returning ingestDocument!"); - return ingestDocument; + } void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index b4a4f21a0..bfcc4eb86 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,25 +5,17 @@ package org.opensearch.neuralsearch.ml; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import lombok.SneakyThrows; - import org.junit.Before; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; @@ -35,8 +27,6 @@ import org.opensearch.neuralsearch.constants.TestCommonConstants; import org.opensearch.test.OpenSearchTestCase; -import com.google.common.collect.ImmutableList; - public class MLCommonsClientAccessorTests extends OpenSearchTestCase { @Mock @@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(resultListener); } - @SneakyThrows - public void test_blockingInferenceSentences() { - ActionFuture actionFuture = mock(ActionFuture.class); - when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture); - List tensorsList = new ArrayList<>(); - - List tensors = new ArrayList<>(); - ModelTensor tensor = mock(ModelTensor.class); - when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY); - tensors.add(tensor); - - ModelTensors modelTensors = new ModelTensors(tensors); - tensorsList.add(modelTensors); - - ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList); - when(actionFuture.get()).thenReturn(mlOutput); - List> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock")); - assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0)); - } - private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 15af12157..243da88e4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -7,14 +7,14 @@ import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import java.util.function.Supplier; import org.junit.Before; @@ -22,6 +22,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.action.ActionListener; import org.opensearch.common.settings.Settings; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; @@ -59,7 +60,7 @@ private TextEmbeddingProcessor createInstance(List> vector) throws E config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - when(mlCommonsClientAccessor.inferenceSentences(anyString(), anyList())).thenReturn(vector); + doReturn(vector).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); return processor; } @@ -95,8 +96,17 @@ public void testExecute_successful() throws Exception { sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); } public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception { @@ -112,12 +122,10 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - when(accessor.inferenceSentences(anyString(), anyList())).thenThrow(new InterruptedException()); - try { - processor.execute(ingestDocument); - } catch (RuntimeException e) { - assertEquals("Text embedding processor failed with exception", e.getMessage()); - } + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); } public void testExecute_withListTypeInput_successful() throws Exception { @@ -128,8 +136,17 @@ public void testExecute_withListTypeInput_successful() throws Exception { sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); } public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -137,11 +154,10 @@ public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentExcep sourceAndMetadata.put("key1", " "); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("field [key1] has empty string value, can not process it", e.getMessage()); - } + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -150,11 +166,10 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() sourceAndMetadata.put("key1", list1); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key1] has empty string, can not process it", e.getMessage()); - } + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception { @@ -163,11 +178,9 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th sourceAndMetadata.put("key2", list2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key2] has non string value, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception { @@ -179,11 +192,9 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep sourceAndMetadata.put("key2", list); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("list type field [key2] has null, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_withMapTypeInput_successful() throws Exception { @@ -194,8 +205,18 @@ public void testExecute_withMapTypeInput_successful() throws Exception { sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - IngestDocument document = processor.execute(ingestDocument); - assert document.getSourceAndMetadata().containsKey("key1"); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception { @@ -206,11 +227,9 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] has non-string type, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception { @@ -221,11 +240,9 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t sourceAndMetadata.put("key2", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] has empty string, can not process it", e.getMessage()); - } + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception { @@ -235,13 +252,27 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw sourceAndMetadata.put("key2", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); - try { - processor.execute(ingestDocument); - } catch (IllegalArgumentException e) { - assertEquals("map type field [key2] reached max depth limit, can not process it", e.getMessage()); - return; - } - fail("Shouldn't be here, expected exception!"); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } private Map createMaxDepthLimitExceedMap(Supplier maxDepthSupplier) {