Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check state before deleting model or task #725

Merged
merged 3 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

package org.opensearch.ml.action.models;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
Expand All @@ -18,18 +21,26 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand All @@ -44,36 +55,66 @@ public class DeleteModelTransportAction extends HandledTransportAction<ActionReq
static final String SEARCH_FAILURE_MSG = "Search failure while deleting model of ";
static final String OS_STATUS_EXCEPTION_MESSAGE = "Failed to delete all model chunks";
Client client;
NamedXContentRegistry xContentRegistry;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public DeleteModelTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry
) {
super(MLModelDeleteAction.NAME, transportService, actionFilters, MLModelDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.fromActionRequest(request);
String modelId = mlModelDeleteRequest.getModelId();

DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
MLModelGetRequest mlModelGetRequest = new MLModelGetRequest(modelId, false);
FetchSourceContext fetchSourceContext = getFetchSourceContext(mlModelGetRequest.isReturnContent());
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchSourceContext);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelChunks(modelId, deleteResponse, actionListener);
}
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModel mlModel = MLModel.parse(parser);
MLModelState mlModelState = mlModel.getModelState();
if (mlModelState.equals(MLModelState.LOADED)
|| mlModelState.equals(MLModelState.LOADING)
|| mlModelState.equals(MLModelState.PARTIALLY_LOADED)) {
actionListener
.onFailure(
new Exception("Model cannot be deleted in loading or loaded state. Try unloading first and then delete")
);
} else {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelChunks(modelId, deleteResponse, actionListener);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model: " + modelId, e);
if (e instanceof ResourceNotFoundException) {
deleteModelChunks(modelId, null, actionListener);
@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model: " + modelId, e);
if (e instanceof ResourceNotFoundException) {
deleteModelChunks(modelId, null, actionListener);
}
actionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ml model" + r.getId(), e);
actionListener.onFailure(e);
}
actionListener.onFailure(e);
}
});
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find model")); }), () -> context.restore()));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From line 81, we are going to restore context before running listener to delete model. Will it fail on security enabled cluster?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested it on security enabled cluster and it works. Adding test results here.
Screenshot 2023-02-20 at 1 13 16 PM

} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,27 @@

package org.opensearch.ml.action.tasks;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.transport.task.MLTaskDeleteAction;
import org.opensearch.ml.common.transport.task.MLTaskDeleteRequest;
import org.opensearch.tasks.Task;
Expand All @@ -28,35 +36,62 @@ public class DeleteTaskTransportAction extends HandledTransportAction<ActionRequ

Client client;

NamedXContentRegistry xContentRegistry;

@Inject
public DeleteTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public DeleteTaskTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry
) {
super(MLTaskDeleteAction.NAME, transportService, actionFilters, MLTaskDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLTaskDeleteRequest mlTaskDeleteRequest = MLTaskDeleteRequest.fromActionRequest(request);
String taskId = mlTaskDeleteRequest.getTaskId();

DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}
client.get(getRequest, ActionListener.wrap(r -> {

if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLTask mlTask = MLTask.parse(parser);
MLTaskState mlTaskState = mlTask.getState();
if (mlTaskState.equals(MLTaskState.RUNNING)) {
actionListener.onFailure(new Exception("Task cannot be deleted in running state. Try after sometime"));
} else {
DeleteRequest deleteRequest = new DeleteRequest(ML_TASK_INDEX, taskId);
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ML task " + taskId, e);
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
}
});
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
} catch (Exception e) {
log.error("Failed to delete ML task " + taskId, e);
log.error("Failed to delete ml task " + taskId, e);
actionListener.onFailure(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,25 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.get.GetResult;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.ScrollableHitSource;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -62,27 +72,31 @@ public class DeleteModelTransportActionTests extends OpenSearchTestCase {
@Mock
BulkByScrollResponse bulkByScrollResponse;

@Mock
NamedXContentRegistry xContentRegistry;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

DeleteModelTransportAction deleteModelTransportAction;
MLModelDeleteRequest mlModelDeleteRequest;
ThreadContext threadContext;
MLModel model;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);

mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId("test_id").build();
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client));
deleteModelTransportAction = spy(new DeleteModelTransportAction(transportService, actionFilters, client, xContentRegistry));

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDeleteModel_Success() {
public void testDeleteModel_Success() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
Expand All @@ -96,10 +110,74 @@ public void testDeleteModel_Success() {
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_CheckModelState() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.LOADING);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Model cannot be deleted in loading or loaded state. Try unloading first and then delete",
argumentCaptor.getValue().getMessage()
);
}

public void testDeleteModel_ModelNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new Exception());
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Fail to find model", argumentCaptor.getValue().getMessage());
}

public void testDeleteModel_ResourceNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new ResourceNotFoundException("errorMessage"));
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<ResourceNotFoundException> argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
}

public void testDeleteModelChunks_Success() {
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
doAnswer(invocation -> {
Expand All @@ -112,7 +190,14 @@ public void testDeleteModelChunks_Success() {
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_RuntimeException() {
public void testDeleteModel_RuntimeException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.UPLOADED);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("errorMessage"));
Expand Down Expand Up @@ -198,4 +283,13 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() {
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage());
}

public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException {
MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build();
XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
BytesReference bytesReference = BytesReference.bytes(content);
GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}
}
Loading