Skip to content

Commit

Permalink
fix model group auto-deletion when last version is deleted (#1444)
Browse files Browse the repository at this point in the history
* fix model group auto-deletion when last version is deleted

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit 1f43b28)
  • Loading branch information
rbhavna authored and github-actions[bot] committed Oct 16, 2023
1 parent 701af31 commit 9885bf5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
package org.opensearch.ml.action.models;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
Expand All @@ -34,8 +30,6 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
Expand All @@ -48,7 +42,6 @@
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -110,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
}
MLModel mlModel = MLModel.parse(parser, algorithmName);
MLModelState mlModelState = mlModel.getModelState();

modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
Expand All @@ -118,37 +112,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
.onFailure(
new MLValidationException("User doesn't have privilege to perform this operation on this model")
);
} else if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
wrappedListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else {
MLModelState mlModelState = mlModel.getModelState();
if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
wrappedListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
)
);
} else if (StringUtils.isNotEmpty(mlModel.getModelGroupId())) {
searchModel(mlModel.getModelGroupId(), ActionListener.wrap(response -> {
boolean isLastModelOfGroup = false;
if (response != null
&& response.getHits() != null
&& response.getHits().getTotalHits() != null
&& response.getHits().getTotalHits().value == 1) {
isLastModelOfGroup = true;
}
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, wrappedListener);
}, e -> {
log.error("Failed to Search Model index " + modelId, e);
wrappedListener.onFailure(e);
}));
} else {
deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener);
}
deleteModel(modelId, actionListener);
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
Expand All @@ -168,18 +145,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
}

private void searchModel(String modelGroupId, ActionListener<SearchResponse> listener) {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId));

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);
client.search(searchRequest, ActionListener.wrap(response -> { listener.onResponse(response); }, e -> {
log.error("Failed to search Model index", e);
listener.onFailure(e);
}));
}

@VisibleForTesting
void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener<DeleteResponse> actionListener) {
DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX);
Expand Down Expand Up @@ -218,19 +183,11 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR));
}

private void deleteModel(
String modelId,
String modelGroupId,
boolean isLastModelOfGroup,
ActionListener<DeleteResponse> actionListener
) {
private void deleteModel(String modelId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (isLastModelOfGroup) {
deleteModelGroup(modelGroupId);
}
deleteModelChunks(modelId, deleteResponse, actionListener);
}

Expand All @@ -244,19 +201,4 @@ public void onFailure(Exception e) {
}
});
}

private void deleteModelGroup(String modelGroupId) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group for modelGroupId:{}", modelGroupId);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group with Id:{} " + modelGroupId, e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
package org.opensearch.ml.action.models;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -22,7 +20,6 @@
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
Expand All @@ -35,7 +32,6 @@
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -50,15 +46,11 @@
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -184,115 +176,6 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException {
verify(actionListener).onResponse(deleteResponse);
}

public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(1);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(client).search(any(), isA(ActionListener.class));

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void test_Success_ModelGroupIDNotNull_NotLastModelOfGroup() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

SearchResponse searchResponse = createModelGroupSearchResponse(2);
doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onResponse(searchResponse);
return null;
}).when(client).search(any(), isA(ActionListener.class));

MLModel mlModel = MLModel
.builder()
.modelId("test_id")
.modelGroupId("modelGroupID")
.modelState(MLModelState.REGISTERED)
.algorithm(FunctionName.TEXT_EMBEDDING)
.build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void test_Failure_FailedToSearchLastModel() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

doAnswer(invocation -> {
ActionListener<SearchResponse> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Failed to search Model index"));
return null;
}).when(client).search(any(), isA(ActionListener.class));

GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");

doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to search Model index", argumentCaptor.getValue().getMessage());
}

public void test_UserHasNoAccessException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID");
doAnswer(invocation -> {
Expand Down Expand Up @@ -517,20 +400,4 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}

private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {
SearchResponse searchResponse = mock(SearchResponse.class);
String modelContent = "{\n"
+ " \"created_time\": 1684981986069,\n"
+ " \"access\": \"public\",\n"
+ " \"latest_version\": 0,\n"
+ " \"last_updated_time\": 1684981986069,\n"
+ " \"name\": \"model_group_IT\",\n"
+ " \"description\": \"This is an example description\"\n"
+ " }";
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
when(searchResponse.getHits()).thenReturn(hits);
return searchResponse;
}
}

0 comments on commit 9885bf5

Please sign in to comment.