diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 3f434d3ce..55ee89bbd 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -177,14 +177,16 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { private Map buildMapResultFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); - if (CollectionUtils.isEmpty(tensorOutputList)) { - log.error("No tensor output found!"); - return null; + if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { + throw new IllegalStateException( + "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]" + ); } List tensorList = tensorOutputList.get(0).getMlModelTensors(); - if (CollectionUtils.isEmpty(tensorList)) { - log.error("No tensor found!"); - return null; + if (tensorList.size() != 1) { + throw new IllegalStateException( + "Unexpected number of map result produced. Expected 1 map result to be returned, but got [" + tensorList.size() + "]" + ); } return tensorList.get(0).getDataAsMap(); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 813742878..d7c2cddcb 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -16,12 +16,14 @@ import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; @@ -177,7 +179,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() { + public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { final ActionListener> resultListener = mock(ActionListener.class); final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); Mockito.doAnswer(invocation -> { @@ -189,11 +191,13 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenR Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(null); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() { + public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { final ActionListener> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -208,7 +212,41 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenRe Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(null); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenException() { + final ActionListener> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "response", + null, + null, + null, + null, + null, + Map.of("key", "value") + ); + mlModelTensorList.add(tensor); + mlModelTensorList.add(tensor); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Unexpected number of map result produced. Expected 1 map result to be returned, but got [2]", argumentCaptor.getValue().getMessage()); Mockito.verifyNoMoreInteractions(resultListener); }