Skip to content

Commit

Permalink
Add map result support in neural search for non text embedding models
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Aug 22, 2023
1 parent 70675c8 commit 5892f92
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -100,10 +102,38 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
public void inferenceSentencesWithMapResult(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<Map<String, ?>> listener) {
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<Map<String, ?>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final Map<String, ?> result = buildMapResultFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, result);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
Expand All @@ -118,7 +148,7 @@ private void inferenceSentencesWithRetry(
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
Expand All @@ -144,4 +174,19 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private Map<String, ?> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList)) {
log.error("No tensor output found!");
return null;
}
List<ModelTensor> tensorList = tensorOutputList.get(0).getMlModelTensors();
if (CollectionUtils.isEmpty(tensorList)) {
log.error("No tensor found!");
return null;
}
return tensorList.get(0).getDataAsMap();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.ml;

import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;

Expand All @@ -13,6 +14,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.InjectMocks;
Expand Down Expand Up @@ -160,6 +162,98 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() {
Mockito.verify(resultListener).onFailure(illegalStateException);
}

public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
final Map<String, String> map = ImmutableMap.of("key", "value");
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(createModelTensorOutput(map));
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(map);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList());
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
tensorsList.add(new ModelTensors(mlModelTensorList));
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() {
final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException(
mock(DiscoveryNode.class),
"Node not connected"
);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(nodeNodeConnectedException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
accessor.inferenceSentencesWithMapResult(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(4))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(nodeNodeConnectedException);
}

public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() {
final IllegalStateException illegalStateException = new IllegalStateException("Illegal state");
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onFailure(illegalStateException);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
accessor.inferenceSentencesWithMapResult(
TestCommonConstants.MODEL_ID,
TestCommonConstants.SENTENCES_LIST,
resultListener
);

Mockito.verify(client, times(1))
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onFailure(illegalStateException);
}

private ModelTensorOutput createModelTensorOutput(final Float[] output) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand All @@ -175,4 +269,22 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) {
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}

private ModelTensorOutput createModelTensorOutput(final Map<String, String> map) {
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
final ModelTensor tensor = new ModelTensor(
"response",
null,
null,
null,
null,
null,
map
);
mlModelTensorList.add(tensor);
final ModelTensors modelTensors = new ModelTensors(mlModelTensorList);
tensorsList.add(modelTensors);
return new ModelTensorOutput(tensorsList);
}
}

0 comments on commit 5892f92

Please sign in to comment.