Skip to content

Commit

Permalink
Change text embedding processor to async mode
Browse files Browse the repository at this point in the history
Signed-off-by: Zan Niu <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Oct 24, 2022
1 parent d168e95 commit fbef895
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -110,29 +108,6 @@ public void inferenceSentences(
}, listener::onFailure));
}

/**
* Abstraction to call predict function of api of MLClient with provided targetResponseFilters. It uses the
* custom model provided as modelId and run the {@link MLModelTaskType#TEXT_EMBEDDING}. The return will be sent
* using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of
* inputText. We are not making this function generic enough to take any function or TaskType as currently we need
* to run only TextEmbedding tasks only. Please note this method is a blocking method, use this only when the processing
* needs block waiting for response, otherwise please use {@link #inferenceSentences(String, List, ActionListener)}
* instead.
* @param modelId {@link String}
* @param inputText {@link List} of {@link String} on which inference needs to happen.
* @return {@link List} of {@link List} of {@link String} represents the text embedding vector result.
* @throws ExecutionException If the underlying task failed, this exception will be thrown in the future.get().
* @throws InterruptedException If the thread is interrupted, this will be thrown.
*/
public List<List<Float>> inferenceSentences(@NonNull final String modelId, @NonNull final List<String> inputText)
throws ExecutionException, InterruptedException {
final MLInput mlInput = createMLInput(TARGET_RESPONSE_FILTERS, inputText);
final ActionFuture<MLOutput> outputActionFuture = mlClient.predict(modelId, mlInput);
final List<List<Float>> vector = buildVectorFromResponse(outputActionFuture.get());
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
return vector;
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
Expand Down Expand Up @@ -80,17 +81,21 @@ private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {

@Override
public IngestDocument execute(IngestDocument ingestDocument) {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
return ingestDocument;
}

public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
List<List<Float>> vectors = mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap));
appendVectorFieldsToDocument(ingestDocument, knnMap, vectors);
} catch (ExecutionException | InterruptedException e) {
log.error("Text embedding processor failed with exception: ", e);
throw new RuntimeException("Text embedding processor failed with exception", e);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument);
mlCommonsClientAccessor.inferenceSentences(this.modelId, createInferenceList(knnMap), ActionListener.wrap(x -> {
appendVectorFieldsToDocument(ingestDocument, knnMap, x);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
} catch (Exception e) {
handler.accept(null, e);
}
log.debug("Text embedding completed, returning ingestDocument!");
return ingestDocument;

}

void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,17 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import lombok.SneakyThrows;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -35,8 +27,6 @@
import org.opensearch.neuralsearch.constants.TestCommonConstants;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableList;

public class MLCommonsClientAccessorTests extends OpenSearchTestCase {

@Mock
Expand Down Expand Up @@ -124,26 +114,6 @@ public void testInferenceSentences_whenExceptionFromMLClient_thenFailure() {
Mockito.verifyNoMoreInteractions(resultListener);
}

@SneakyThrows
public void test_blockingInferenceSentences() {
ActionFuture actionFuture = mock(ActionFuture.class);
when(client.predict(anyString(), any(MLInput.class))).thenReturn(actionFuture);
List<ModelTensors> tensorsList = new ArrayList<>();

List<ModelTensor> tensors = new ArrayList<>();
ModelTensor tensor = mock(ModelTensor.class);
when(tensor.getData()).thenReturn(TestCommonConstants.PREDICT_VECTOR_ARRAY);
tensors.add(tensor);

ModelTensors modelTensors = new ModelTensors(tensors);
tensorsList.add(modelTensors);

ModelTensorOutput mlOutput = new ModelTensorOutput(tensorsList);
when(actionFuture.get()).thenReturn(mlOutput);
List<List<Float>> result = accessor.inferenceSentences("modelId", ImmutableList.of("mock"));
assertEquals(TestCommonConstants.PREDICT_VECTOR_ARRAY[0], result.get(0).get(0));
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@

import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.junit.Before;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.ActionListener;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
Expand Down Expand Up @@ -59,7 +60,7 @@ private TextEmbeddingProcessor createInstance(List<List<Float>> vector) throws E
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
when(mlCommonsClientAccessor.inferenceSentences(anyString(), anyList())).thenReturn(vector);
doReturn(vector).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));
return processor;
}

Expand Down Expand Up @@ -95,8 +96,17 @@ public void testExecute_successful() throws Exception {
sourceAndMetadata.put("key2", "value2");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
IngestDocument document = processor.execute(ingestDocument);
assert document.getSourceAndMetadata().containsKey("key1");

List<List<Float>> modelTensorList = createMockVectorResult();
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() throws Exception {
Expand All @@ -112,12 +122,10 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped"));
TextEmbeddingProcessor processor = textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config);
when(accessor.inferenceSentences(anyString(), anyList())).thenThrow(new InterruptedException());
try {
processor.execute(ingestDocument);
} catch (RuntimeException e) {
assertEquals("Text embedding processor failed with exception", e.getMessage());
}
doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(RuntimeException.class));
}

public void testExecute_withListTypeInput_successful() throws Exception {
Expand All @@ -128,20 +136,28 @@ public void testExecute_withListTypeInput_successful() throws Exception {
sourceAndMetadata.put("key2", list2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6));
IngestDocument document = processor.execute(ingestDocument);
assert document.getSourceAndMetadata().containsKey("key1");

List<List<Float>> modelTensorList = createMockVectorResult();
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(any(IngestDocument.class), isNull());
}

public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() throws Exception {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", " ");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("field [key1] has empty string value, can not process it", e.getMessage());
}

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
Expand All @@ -150,11 +166,10 @@ public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException()
sourceAndMetadata.put("key1", list1);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("list type field [key1] has empty string, can not process it", e.getMessage());
}

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNonStringValue_throwIllegalArgumentException() throws Exception {
Expand All @@ -163,11 +178,9 @@ public void testExecute_listHasNonStringValue_throwIllegalArgumentException() th
sourceAndMetadata.put("key2", list2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("list type field [key2] has non string value, can not process it", e.getMessage());
}
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_listHasNull_throwIllegalArgumentException() throws Exception {
Expand All @@ -179,11 +192,9 @@ public void testExecute_listHasNull_throwIllegalArgumentException() throws Excep
sourceAndMetadata.put("key2", list);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("list type field [key2] has null, can not process it", e.getMessage());
}
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_withMapTypeInput_successful() throws Exception {
Expand All @@ -194,8 +205,18 @@ public void testExecute_withMapTypeInput_successful() throws Exception {
sourceAndMetadata.put("key2", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
IngestDocument document = processor.execute(ingestDocument);
assert document.getSourceAndMetadata().containsKey("key1");

List<List<Float>> modelTensorList = createMockVectorResult();
doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onResponse(modelTensorList);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(any(IngestDocument.class), isNull());

}

public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() throws Exception {
Expand All @@ -206,11 +227,9 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() thr
sourceAndMetadata.put("key2", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("map type field [key2] has non-string type, can not process it", e.getMessage());
}
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() throws Exception {
Expand All @@ -221,11 +240,9 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() t
sourceAndMetadata.put("key2", map2);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("map type field [key2] has empty string, can not process it", e.getMessage());
}
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throws Exception {
Expand All @@ -235,13 +252,27 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() throw
sourceAndMetadata.put("key2", ret);
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));
try {
processor.execute(ingestDocument);
} catch (IllegalArgumentException e) {
assertEquals("map type field [key2] reached max depth limit, can not process it", e.getMessage());
return;
}
fail("Shouldn't be here, expected exception!");
BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

public void testExecute_MLClientAccessorThrowFail_handlerFailure() throws Exception {
Map<String, Object> sourceAndMetadata = new HashMap<>();
sourceAndMetadata.put("key1", "value1");
sourceAndMetadata.put("key2", "value2");
IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
TextEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2));

doAnswer(invocation -> {
ActionListener<List<List<Float>>> listener = invocation.getArgument(2);
listener.onFailure(new IllegalArgumentException("illegal argument"));
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class));

BiConsumer handler = mock(BiConsumer.class);
processor.execute(ingestDocument, handler);
verify(handler).accept(isNull(), any(IllegalArgumentException.class));
}

private Map<String, Object> createMaxDepthLimitExceedMap(Supplier<Integer> maxDepthSupplier) {
Expand Down

0 comments on commit fbef895

Please sign in to comment.