Skip to content

Commit

Permalink
Add more UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Sep 1, 2023
1 parent 8dddbeb commit e4f8780
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,16 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
private Map<String, ?> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList)) {
log.error("No tensor output found!");
return null;
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
throw new IllegalStateException(
"Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]"
);
}
List<ModelTensor> tensorList = tensorOutputList.get(0).getMlModelTensors();
if (CollectionUtils.isEmpty(tensorList)) {
log.error("No tensor found!");
return null;
if (tensorList.size() != 1) {
throw new IllegalStateException(
"Unexpected number of map result produced. Expected 1 map result to be returned, but got [" + tensorList.size() + "]"
);
}
return tensorList.get(0).getDataAsMap();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
Expand Down Expand Up @@ -177,7 +179,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() {
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() {
public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList());
Mockito.doAnswer(invocation -> {
Expand All @@ -189,11 +191,13 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenR

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
Mockito.verify(resultListener).onFailure(argumentCaptor.capture());
assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage());
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() {
public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
Expand All @@ -208,7 +212,41 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenRe

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
Mockito.verify(resultListener).onResponse(null);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
Mockito.verify(resultListener).onFailure(argumentCaptor.capture());
assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage());
Mockito.verifyNoMoreInteractions(resultListener);
}

public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenException() {
final ActionListener<Map<String, ?>> resultListener = mock(ActionListener.class);
final List<ModelTensors> tensorsList = new ArrayList<>();
final List<ModelTensor> mlModelTensorList = new ArrayList<>();
final ModelTensor tensor = new ModelTensor(
"response",
null,
null,
null,
null,
null,
Map.of("key", "value")
);
mlModelTensorList.add(tensor);
mlModelTensorList.add(tensor);
tensorsList.add(new ModelTensors(mlModelTensorList));
final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList);
Mockito.doAnswer(invocation -> {
final ActionListener<MLOutput> actionListener = invocation.getArgument(2);
actionListener.onResponse(modelTensorOutput);
return null;
}).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener);

Mockito.verify(client)
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class));
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
Mockito.verify(resultListener).onFailure(argumentCaptor.capture());
assertEquals("Unexpected number of map result produced. Expected 1 map result to be returned, but got [2]", argumentCaptor.getValue().getMessage());
Mockito.verifyNoMoreInteractions(resultListener);
}

Expand Down

0 comments on commit e4f8780

Please sign in to comment.