Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] [main] Undeploy code change: add edge case for models that are marked as not found in cache #3520 #3551

Open
wants to merge 1 commit into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add edge case for models that are marked as not found in cache (#3523)
This change inspects the model undeploy response to check whether the model has been marked as not found on all nodes.

Signed-off-by: Brian Flores <iflorbri@amazon.com>
(cherry picked from commit c5ceb48)
  • Loading branch information
brianf-aws authored and github-actions[bot] committed Feb 14, 2025
commit ac2714751cf9bfdb41bca8edee52210aca773b9a
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package org.opensearch.ml.action.undeploy;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.ExceptionsHelper;
Expand Down Expand Up @@ -198,7 +200,19 @@ private void undeployModels(
* Having this change enables a check that this edge case occurs along with having access to the model id
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
*/
if (response.getNodes().isEmpty()) {
boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> {
Map<String, String> status = nodeResponse.getModelUndeployStatus();
if (status == null)
return false;
// Stream is used to catch all models edge case but only one is ever undeployed
boolean modelCacheMissForModelIds = Arrays.stream(modelIds).allMatch(modelId -> {
String modelStatus = status.get(modelId);
return modelStatus != null && modelStatus.equalsIgnoreCase(NOT_FOUND);
});

return modelCacheMissForModelIds;
});
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
bulkSetModelIndexToUndeploy(modelIds, listener, response);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
    // Scenario: every node answers the undeploy call with NOT_FOUND for the requested model.
    // Expectation: a bulk request is still issued and the caller's listener receives a response.
    final String targetModelId = this.modelIds[0];

    // Hidden model fixture that the stubbed model manager hands back.
    MLModel hiddenModel = MLModel
        .builder()
        .user(User.parse(USER_STRING))
        .modelGroupId("111")
        .version("111")
        .name(targetModelId)
        .modelId(targetModelId)
        .algorithm(FunctionName.BATCH_RCF)
        .content("content")
        .totalChunks(2)
        .isHidden(true)
        .build();

    // The model manager resolves the lookup immediately with the fixture above.
    doAnswer(invocation -> {
        ActionListener<MLModel> modelListener = invocation.getArgument(4);
        modelListener.onResponse(hiddenModel);
        return null;
    }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

    // Treat the caller as a super admin for this test.
    doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

    // One mocked node response per node id, each reporting NOT_FOUND for the target model.
    List<MLUndeployModelNodeResponse> perNodeResponses = new ArrayList<>();
    for (String ignoredNodeId : this.nodeIds) {
        Map<String, String> undeployStatus = new HashMap<>();
        undeployStatus.put(targetModelId, NOT_FOUND);

        MLUndeployModelNodeResponse mockedNodeResponse = mock(MLUndeployModelNodeResponse.class);
        when(mockedNodeResponse.getModelUndeployStatus()).thenReturn(undeployStatus);
        perNodeResponses.add(mockedNodeResponse);
    }

    // Aggregate response carries no node failures.
    MLUndeployModelNodesResponse aggregatedResponse = new MLUndeployModelNodesResponse(
        clusterName,
        perNodeResponses,
        new ArrayList<FailedNodeException>()
    );

    // The transport-level undeploy call replies with the aggregated NOT_FOUND response.
    doAnswer(invocation -> {
        ActionListener<MLUndeployModelNodesResponse> undeployListener = invocation.getArgument(2);
        undeployListener.onResponse(aggregatedResponse);
        return null;
    }).when(client).execute(any(), any(), isA(ActionListener.class));

    // The bulk call completes successfully with a mocked response.
    doAnswer(invocation -> {
        ActionListener<BulkResponse> bulkListener = invocation.getArgument(1);
        bulkListener.onResponse(mock(BulkResponse.class));
        return null;
    }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

    MLUndeployModelsRequest undeployRequest = new MLUndeployModelsRequest(modelIds, nodeIds, null);
    transportUndeployModelsAction.doExecute(task, undeployRequest, actionListener);

    // Verify that bulk request was fired because all nodes reported "not_found"
    verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
    verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
}

public void testHiddenModelPermissionError() {
MLModel mlModel = MLModel
.builder()
Expand Down
Loading