From 5892f9240f9c29fbfb227bc2eaf97635483e08e0 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 22 Aug 2023 12:05:37 +0800 Subject: [PATCH] Add map result support in neural search for non text embedding models Signed-off-by: zane-neo --- .../ml/MLCommonsClientAccessor.java | 51 +++++++- .../ml/MLCommonsClientAccessorTests.java | 112 ++++++++++++++++++ 2 files changed, 160 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 768584ec9..3f434d3ce 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import lombok.NonNull; @@ -15,6 +16,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -100,10 +102,38 @@ public void inferenceSentences( @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } - private void inferenceSentencesWithRetry( + public void inferenceSentencesWithMapResult( + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener> listener) { + retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + } + + private void retryableInferenceSentencesWithMapResult( + final String modelId, + final List inputText, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLInput(null, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final Map result = buildMapResultFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputText, result); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, @@ -118,7 +148,7 @@ private void inferenceSentencesWithRetry( }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { final int retryTimeAdd = retryTime + 1; - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); } else { listener.onFailure(e); } @@ -144,4 +174,19 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } + private Map buildMapResultFromResponse(MLOutput mlOutput) { + final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; + final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); + if (CollectionUtils.isEmpty(tensorOutputList)) { + log.error("No tensor output found!"); + return null; + } + List tensorList = tensorOutputList.get(0).getMlModelTensors(); + if (CollectionUtils.isEmpty(tensorList)) { + log.error("No tensor found!"); + return null; + } + return tensorList.get(0).getDataAsMap(); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 350394250..41cb4d12c 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -13,6 +14,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.mockito.InjectMocks; @@ -160,6 +162,98 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { + final Map map = ImmutableMap.of("key", "value"); + final ActionListener> resultListener = mock(ActionListener.class); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(map)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(map); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() { + final ActionListener> resultListener = mock(ActionListener.class); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(null); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() { + final ActionListener> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(null); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); + } + + public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { + final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(illegalStateException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(1)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(illegalStateException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -175,4 +269,22 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createModelTensorOutput(final Map map) { + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "response", + null, + null, + null, + null, + null, + map + ); + mlModelTensorList.add(tensor); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + tensorsList.add(modelTensors); + return new ModelTensorOutput(tensorsList); + } }