Skip to content

Commit

Permalink
more UT for rest and trasport actions
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Jul 11, 2023
1 parent 50ca665 commit 3e08582
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 7 deletions.
8 changes: 1 addition & 7 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.action.training.TrainingITTests',
'org.opensearch.ml.action.prediction.PredictionITTests',
'org.opensearch.ml.cluster.MLSyncUpCron',
'org.opensearch.ml.action.connector.GetConnectorTransportAction',
'org.opensearch.ml.breaker.MemoryCircuitBreaker',
'org.opensearch.ml.action.connector.DeleteConnectorTransportAction',
'org.opensearch.ml.action.connector.DeleteConnectorTransportAction.1',
'org.opensearch.ml.action.connector.TransportCreateConnectorAction',
'org.opensearch.ml.action.connector.SearchConnectorTransportAction',
'org.opensearch.ml.rest.RestMLCreateConnectorAction'
'org.opensearch.ml.breaker.MemoryCircuitBreaker'
]

jacocoTestCoverageVerification {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.connector;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.client.Response;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DeleteConnectorTransportActionTests extends OpenSearchTestCase {
@Mock
ThreadPool threadPool;

@Mock
Client client;

@Mock
TransportService transportService;

@Mock
ActionFilters actionFilters;

@Mock
ActionListener<DeleteResponse> actionListener;

@Mock
DeleteResponse deleteResponse;

@Mock
NamedXContentRegistry xContentRegistry;

@Mock
private MLModelManager mlModelManager;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Mock
ClusterService clusterService;

DeleteConnectorTransportAction deleteConnectorTransportAction;
MLConnectorDeleteRequest mlConnectorDeleteRequest;
ThreadContext threadContext;
MLModel model;

@Mock
private ConnectorAccessControlHelper connectorAccessControlHelper;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId("connector_id").build();

Settings settings = Settings.builder().build();
deleteConnectorTransportAction = spy(
new DeleteConnectorTransportAction(
transportService,
actionFilters,
client,
xContentRegistry,
connectorAccessControlHelper
)
);

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(true);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any());

threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDeleteConnector_Success() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

SearchResponse searchResponse = getEmptySearchResponse();
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteConnector_BlockedByModel() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

SearchResponse searchResponse = getNonEmptySearchResponse();
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("1 models are still using this connector, please delete or update the models first!", argumentCaptor.getValue().getMessage());
}

public void test_UserHasNoAccessException() throws IOException {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onResponse(false);
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("You are not allowed to delete this connector", argumentCaptor.getValue().getMessage());
}

public void testDeleteConnector_SearchFailure() throws IOException {
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new RuntimeException("Search Failed!"));
return null;
}).when(client).search(any(), any());

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new ResourceNotFoundException("errorMessage"));
return null;
}).when(client).delete(any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Search Failed!", argumentCaptor.getValue().getMessage());
}

public void testDeleteConnector_ResourceNotFoundException() throws IOException {
SearchResponse searchResponse = getEmptySearchResponse();
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(), any());

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new ResourceNotFoundException("errorMessage"));
return null;
}).when(client).delete(any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

public void test_ValidationFailedException() throws IOException {
GetResponse getResponse = prepareMLConnector();
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).search(any(), any());

doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(2);
listener.onFailure(new Exception("Failed to validate access"));
return null;
}).when(connectorAccessControlHelper).validateConnectorAccess(any(), any(), any());

deleteConnectorTransportAction.doExecute(null, mlConnectorDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLConnector() throws IOException {
HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").build();
XContentBuilder content = connector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}

private SearchResponse getEmptySearchResponse() {
SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN);
SearchResponseSections searchSections = new SearchResponseSections(
hits,
InternalAggregations.EMPTY,
null,
true,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
return searchResponse;
}

private SearchResponse getNonEmptySearchResponse() {
SearchHit[] hits = new SearchHit[1];
SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f);
SearchResponseSections searchSections = new SearchResponseSections(
searchHits,
InternalAggregations.EMPTY,
null,
true,
false,
null,
1
);
SearchResponse searchResponse = new SearchResponse(
searchSections,
null,
1,
1,
0,
11,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
return searchResponse;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import static org.opensearch.ml.utils.TestHelper.verifyParsedCreateConnectorInput;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
Expand All @@ -26,6 +28,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand All @@ -34,6 +37,7 @@
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;

Expand Down Expand Up @@ -107,4 +111,13 @@ public void testPrepareRequest() throws Exception {
MLCreateConnectorInput mlCreateConnectorInput = argumentCaptor.getValue().getMlCreateConnectorInput();
verifyParsedCreateConnectorInput(mlCreateConnectorInput);
}

public void testPrepareRequest_EmptyContent() throws Exception {
thrown.expect(IOException.class);
thrown.expectMessage("Create Connector request has empty body");
Map<String, String> params = new HashMap<>();
RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build();

restMLCreateConnectorAction.handleRequest(request, channel, client);
}
}

0 comments on commit 3e08582

Please sign in to comment.