Skip to content

Commit

Permalink
check state before deleting model or task (opensearch-project#725)
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Goud Ramaram <rbhavna@amazon.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
rbhavna authored and ylwu-amzn committed Mar 4, 2023
1 parent c738cc4 commit 3697d03
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

package org.opensearch.ml.action.models;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
Expand All @@ -18,18 +21,26 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -44,36 +55,66 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks";
Client client;
NamedXContentRegistry xContentRegistry;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public DeleteModelTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry
) {
super(MLModelDeleteAction.NAME, transportService, actionFilters, MLModelDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request);
String modelId = mlModelDeleteRequest.getModelId();

DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false);
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent());
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelChunks(modelId, deleteResponse, actionListener);
}
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModel mlModel = MLModel.parse(parser);
MLModelState mlModelState = mlModel.getModelState();
if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)) {
actionListener
.onFailure(
new Exception("Model cannot be deleted in loading or loaded state. Try unloading first and then delete")
);
} else {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelChunks(modelId, deleteResponse, actionListener);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model: " + modelId, e);
if (e instanceof ResourceNotFoundException) {
deleteModelChunks(modelId, null, actionListener);
@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model: " + modelId, e);
if (e instanceof ResourceNotFoundException) {
deleteModelChunks(modelId, null, actionListener);
}
actionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ml model" + r.getId(), e);
actionListener.onFailure(e);
}
actionListener.onFailure(e);
}
});
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find model")); }), () -> context.restore()));
} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,27 @@

package org.opensearch.ml.action.tasks;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.tasks.Task;
Expand All @@ -28,35 +36,62 @@ public class DeleteTaskTransportAction extends HandledTransportAction<ActionRequ

Client client;

NamedXContentRegistry xContentRegistry;

@Inject
public DeleteTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public DeleteTaskTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry
) {
super(MLTaskDeleteAction.NAME, transportService, actionFilters, MLTaskDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.fromActionRequest(request);
String taskId = mlTaskDeleteRequest.getTaskId();

DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}
client.get(getRequest, ActionListener.wrap(r -> {

if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLTask mlTask = MLTask.parse(parser);
MLTaskState mlTaskState = mlTask.getState();
if (mlTaskState.equals(MLTaskState.RUNNING)) {
actionListener.onFailure(new Exception("Task cannot be deleted in running state. Try after sometime"));
} else {
DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ML task " + taskId, e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
}
});
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
} catch (Exception e) {
log.error("Failed to delete ML task " + taskId, e);
log.error("Failed to delete ml task " + taskId, e);
actionListener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,25 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -62,27 +72,31 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
@Mock
BulkByScrollResponse bulkByScrollResponse;

@Mock
NamedXContentRegistry xContentRegistry;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

DeleteModelTransportAction deleteModelTransportAction;
MLModelDeleteRequest mlModelDeleteRequest;
ThreadContext threadContext;
MLModel model;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build();
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client));
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client, xContentRegistry));

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDeleteModel_Success() {
public void testDeleteModel_Success() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
Expand All @@ -96,10 +110,74 @@ public void testDeleteModel_Success() {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_CheckModelState() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.LOADING);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Model cannot be deleted in loading or loaded state. Try unloading first and then delete",
argumentCaptor.getValue().getMessage()
);
}

public void testDeleteModel_ModelNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new Exception());
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Fail to find model", argumentCaptor.getValue().getMessage());
}

public void testDeleteModel_ResourceNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new ResourceNotFoundException("errorMessage"));
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<ResourceNotFoundException> argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

public void testDeleteModelChunks_Success() {
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
doAnswer(invocation -> {
Expand All @@ -112,7 +190,14 @@ public void testDeleteModelChunks_Success() {
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_RuntimeException() {
public void testDeleteModel_RuntimeException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("errorMessage"));
Expand Down Expand Up @@ -198,4 +283,13 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}
}
Loading

0 comments on commit 3697d03

Please sign in to comment.