Skip to content

Commit

Permalink
fix model stuck in deploying state during node crash/cluster restart (#…
Browse files Browse the repository at this point in the history
…3137)

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
(cherry picked from commit bb6339f)
  • Loading branch information
rbhavna committed Jan 10, 2025
1 parent 589e24b commit 9048e7f
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -131,10 +135,29 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
syncModelWorkerNodes(modelId, functionName);
}

if (workNodes == null || workNodes.size() == 0) {
Set<String> workNodesRemovedFromCluster = new HashSet<>();

if (workNodes != null && !workNodes.isEmpty()) {
Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService)));

workNodesRemovedFromCluster = workNodes
.stream()
.filter(node -> !allNodesInCluster.contains(node))
.collect(Collectors.toSet());

if (!workNodesRemovedFromCluster.isEmpty()) {
workNodes.removeAll(workNodesRemovedFromCluster);
}
}

if (workNodes == null || workNodes.isEmpty()) {
if (!workNodesRemovedFromCluster.isEmpty()) {
mlTaskCache.updateWorkerNode(workNodesRemovedFromCluster);
mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
}
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
taskState = MLTaskState.FAILED;
currentWorkerNodeCount = 0;
} else {
Expand All @@ -150,11 +173,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);

MLModelState modelState;
if (!mlTaskCache.allNodeFailed()) {
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
} else {
if (mlTaskCache.allNodeFailed() || mlTaskCache.getWorkerNodeSize() == 0) {
modelState = MLModelState.DEPLOY_FAILED;
log.error("deploy model failed on all nodes, model id: {}", modelId);
} else {
modelState = mlTaskCache.hasError() ? MLModelState.PARTIALLY_DEPLOYED : MLModelState.DEPLOYED;
}
Map<String, Object> updateFields = new HashMap<>();
updateFields.put(MLModel.MODEL_STATE_FIELD, modelState);
Expand Down
5 changes: 5 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public int errorNodesCount() {
public boolean allNodeFailed() {
return workerNodeSize != null && errors.size() == workerNodeSize;
}

public void updateWorkerNode(Set<String> nodesRemovedFromCluster) {
this.workerNodes.removeAll(nodesRemovedFromCluster);
this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;
import static org.opensearch.ml.utils.TestHelper.setupTestClusterState;

import java.util.Arrays;
import java.util.HashSet;
Expand All @@ -43,6 +44,7 @@
import org.opensearch.Version;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
Expand Down Expand Up @@ -94,6 +96,8 @@ public class TransportForwardActionTests extends OpenSearchTestCase {

private TransportForwardAction forwardAction;

private ClusterState testState;

Settings settings = Settings
.builder()
.put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true)
Expand Down Expand Up @@ -137,6 +141,9 @@ public void setup() {
)
);

testState = setupTestClusterState("test_node_id2");
when(clusterService.state()).thenReturn(testState);

node1 = new DiscoveryNode(nodeId1, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
node2 = new DiscoveryNode(nodeId2, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public void setup() throws IOException {
encryptor = spy(new EncryptorImpl(null));
syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor);

testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void setup() throws IOException {
.build();

clusterName = new ClusterName("test cluster");
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void setup() throws IOException {
roleSet,
Version.CURRENT
);
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);

clusterName = new ClusterName(clusterNameStr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class RestMLUndeployModelActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
testState = setupTestClusterState();
testState = setupTestClusterState("node");
when(clusterService.state()).thenReturn(testState);
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
restMLUndeployModelAction = new RestMLUndeployModelAction(clusterService, settings);
Expand Down
4 changes: 2 additions & 2 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,11 @@ public static ClusterState state(int numDataNodes, String indexName, String mapp
return state(new ClusterName("test"), indexName, mapping, clusterManagerNode, clusterManagerNode, allNodes);
}

public static ClusterState setupTestClusterState() {
public static ClusterState setupTestClusterState(String nodeId) {
Set<DiscoveryNodeRole> roleSet = new HashSet<>();
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
DiscoveryNode node = new DiscoveryNode(
"node",
nodeId,
new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()),
new HashMap<>(),
roleSet,
Expand Down

0 comments on commit 9048e7f

Please sign in to comment.