Skip to content

Commit

Permalink
[FEATURE]Improve test coverage for RemoteModel.java (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#3205)

* [FEATURE]Improve test coverage for RemoteModel.java

Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict().
Also renamed some tests to match with testing methods.

Resolves opensearch-project#1382

Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>

* [FEATURE]Improve test coverage for RemoteModel.java

Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict().
Also renamed some tests to match with testing methods.

Resolves opensearch-project#1382

Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>

* [FEATURE]Improve test coverage for RemoteModel.java

Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict().
Also renamed some tests to match with testing methods.

Resolves opensearch-project#1382

Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>

* [FEATURE]Improve test coverage for RemoteModel.java

Added new tests for missing coverage. Mainly coverage was missing for catching exceptions in the methods initModel() and asyncPredict().
Also renamed some tests to match with testing methods.

Resolves opensearch-project#1382

Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>

---------

Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com>
  • Loading branch information
akolarkunnu authored Nov 14, 2024
1 parent 5cfdc3c commit 45ff4f5
Showing 1 changed file with 83 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.junit.Assert;
Expand All @@ -23,28 +24,36 @@
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorProtocols;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.MLStaticMockBase;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import com.google.common.collect.ImmutableMap;

public class RemoteModelTest {
public class RemoteModelTest extends MLStaticMockBase {

@Mock
MLInput mlInput;

@Mock
MLModel mlModel;

@Mock
RemoteConnectorExecutor remoteConnectorExecutor;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -73,7 +82,7 @@ public void test_predict_throw_IllegalStateException() {
}

@Test
public void predict_NullConnectorExecutor() {
public void asyncPredict_NullConnectorExecutor() {
ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
Expand All @@ -86,7 +95,18 @@ public void predict_NullConnectorExecutor() {
}

@Test
public void predict_ModelDeployed_WrongInput() {
public void asyncPredict_ModelDeployed_WrongInput() {
asyncPredict_ModelDeployed_WrongInput("pre_process_function not defined in connector");
}

@Test
public void asyncPredict_With_RemoteInferenceInputDataSet() {
when(mlInput.getInputDataset()).thenReturn(
new RemoteInferenceInputDataSet(Collections.emptyMap(), ConnectorAction.ActionType.BATCH_PREDICT));
asyncPredict_ModelDeployed_WrongInput("no BATCH_PREDICT action found");
}

private void asyncPredict_ModelDeployed_WrongInput(String expExceptionMessage) {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
Expand All @@ -95,16 +115,71 @@ public void predict_ModelDeployed_WrongInput() {
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue() instanceof RuntimeException;
assertEquals("pre_process_function not defined in connector", argumentCaptor.getValue().getMessage());
assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage());
}

@Test
public void initModel_RuntimeException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Tag mismatch!");
public void asyncPredict_Failure_With_RuntimeException() {
asyncPredict_Failure_With_Throwable(
new RuntimeException("Remote Connection Exception!"),
RuntimeException.class,
"Remote Connection Exception!"
);
}

@Test
public void asyncPredict_Failure_With_Throwable() {
asyncPredict_Failure_With_Throwable(
new Error("Remote Connection Error!"),
MLException.class,
"java.lang.Error: Remote Connection Error!"
);
}

private void asyncPredict_Failure_With_Throwable(
Throwable actualException,
Class<? extends Throwable> expExceptionClass,
String expExceptionMessage
) {
ActionListener<MLTaskResponse> actionListener = mock(ActionListener.class);
doThrow(actualException)
.when(remoteConnectorExecutor)
.executeAction(ConnectorAction.ActionType.PREDICT.toString(), mlInput, actionListener);
try (MockedStatic<MLEngineClassLoader> loader = mockStatic(MLEngineClassLoader.class)) {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
loader
.when(() -> MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class))
.thenReturn(remoteConnectorExecutor);
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
remoteModel.asyncPredict(mlInput, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert expExceptionClass.isInstance(argumentCaptor.getValue());
assertEquals(expExceptionMessage, argumentCaptor.getValue().getMessage());
}
}

@Test
public void initModel_Failure_With_RuntimeException() {
initModel_Failure_With_Throwable(new IllegalArgumentException("Tag mismatch!"), IllegalArgumentException.class, "Tag mismatch!");
}

@Test
public void initModel_Failure_With_Throwable() {
initModel_Failure_With_Throwable(new Error("Decryption Error!"), MLException.class, "Decryption Error!");
}

private void initModel_Failure_With_Throwable(
Throwable actualException,
Class<? extends Throwable> expExcepClass,
String expExceptionMessage
) {
exceptionRule.expect(expExcepClass);
exceptionRule.expectMessage(expExceptionMessage);
Connector connector = createConnector(null);
when(mlModel.getConnector()).thenReturn(connector);
doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any());
doThrow(actualException).when(encryptor).decrypt(any());
remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor);
}

Expand All @@ -129,7 +204,6 @@ public void initModel_WithHeader() {
Assert.assertNotNull(executor.getConnector().getDecryptedHeaders());
assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));

remoteModel.close();
Assert.assertNull(remoteModel.getConnectorExecutor());
}
Expand Down

0 comments on commit 45ff4f5

Please sign in to comment.