auto deployment for remote models (opensearch-project#2206)
* auto deployment for remote models

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* add auto deploy feature flag

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* add eligible node check and avoid over-deployment

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* dispatch local deploy

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
Zhangxunmt committed Mar 21, 2024
1 parent 045915c commit 4f87254
Showing 17 changed files with 276 additions and 54 deletions.
@@ -52,4 +52,12 @@ public static FunctionName from(String value) {
public static boolean isDLModel(FunctionName functionName) {
return DL_MODELS.contains(functionName);
}

public static boolean needDeployFirst(FunctionName functionName) {
return DL_MODELS.contains(functionName) || functionName == REMOTE;
}

public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) {
return autoDeploymentEnabled && functionName == FunctionName.REMOTE;
}
}
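The two helpers above carry the gating logic for this change: DL and remote models both have to be deployed before they can serve requests, and auto deployment only applies to REMOTE models while the cluster-level flag is enabled. Below is a minimal sketch of how a caller might combine them; the class name and the deploy/predict/isDeployed helpers are hypothetical placeholders rather than the PR's actual call sites, and the FunctionName import assumes the common package shown elsewhere in this diff.

import org.opensearch.ml.common.FunctionName;

/**
 * Sketch only: illustrates the gate formed by needDeployFirst and isAutoDeployEnabled.
 * deployRemoteModel(...), runPrediction(...) and isDeployed(...) are hypothetical placeholders.
 */
class AutoDeployGateSketch {

    boolean predictWithAutoDeploy(String modelId, FunctionName functionName, boolean autoDeploymentEnabled) {
        // DL and remote models must be deployed before they can serve predictions.
        if (FunctionName.needDeployFirst(functionName) && !isDeployed(modelId)) {
            // Auto deployment is limited to REMOTE models on clusters with the flag turned on.
            if (!FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
                return false; // the caller still has to deploy the model explicitly
            }
            deployRemoteModel(modelId); // placeholder: kick off a transparent deploy
        }
        return runPrediction(modelId); // placeholder: run the prediction task
    }

    private boolean isDeployed(String modelId) { return false; }         // placeholder
    private void deployRemoteModel(String modelId) { /* placeholder */ }
    private boolean runPrediction(String modelId) { return true; }       // placeholder
}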
1 change: 0 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
@@ -116,7 +116,6 @@ public class MLModel implements ToXContentObject {
private Integer totalChunks; // model chunk doc only
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

@@ -116,4 +116,4 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest

}

-}
+}
@@ -207,12 +207,12 @@ private void deployModel(
Set<String> allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());

List<DiscoveryNode> eligibleNodes = new ArrayList<>();
-List<String> nodeIds = new ArrayList<>();
+List<String> eligibleNodeIds = new ArrayList<>();
if (!deployToAllNodes) {
for (String nodeId : targetNodeIds) {
if (allEligibleNodeIds.contains(nodeId)) {
eligibleNodes.add(nodeMapping.get(nodeId));
-nodeIds.add(nodeId);
+eligibleNodeIds.add(nodeId);
}
}
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
@@ -234,15 +234,15 @@
}
}
} else {
-nodeIds.addAll(allEligibleNodeIds);
+eligibleNodeIds.addAll(allEligibleNodeIds);
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
}
-if (nodeIds.size() == 0) {
+if (eligibleNodeIds.size() == 0) {
wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
return;
}

log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds));
log.info("Will deploy model on these nodes: {}", String.join(",", eligibleNodeIds));
String localNodeId = clusterService.localNode().getId();

FunctionName algorithm = mlModel.getAlgorithm();
@@ -258,18 +258,18 @@ private void deployModel(
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
-.workerNodes(nodeIds)
+.workerNodes(eligibleNodeIds)
.build();
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
-mlTaskManager.add(mlTask, nodeIds);
+mlTaskManager.add(mlTask, eligibleNodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
}
try {
-mlTaskManager.add(mlTask, nodeIds);
+mlTaskManager.add(mlTask, eligibleNodeIds);
wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
@@ -222,11 +222,19 @@ private void deployModel(
try {
log.debug("start deploying model {}", modelId);
mlModelManager
-.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> {
-if (!coordinatingNodeId.equals(localNodeId)) {
-mlTaskManager.remove(mlTask.getTaskId());
-}
-}));
+.deployModel(
+modelId,
+modelContentHash,
+functionName,
+deployToAllNodes,
+false,
+mlTask,
+ActionListener.runBefore(listener, () -> {
+if (!coordinatingNodeId.equals(localNodeId)) {
+mlTaskManager.remove(mlTask.getTaskId());
+}
+})
+);
} catch (Exception e) {
logException("Failed to deploy model " + modelId, e, log);
listener.onFailure(e);
@@ -20,12 +20,14 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
@@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
updateFields.put(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, 0);
}
log.info("deploy model done with state: {}, model id: {}", modelState, modelId);
-mlModelManager.updateModel(modelId, updateFields);
+ActionListener updateModelListener = ActionListener.<UpdateResponse>wrap(response -> {
+if (response.status() == RestStatus.OK) {
+log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
+} else {
+log.error("Failed to update ML model {}, status: {}", modelId, response.status());
+}
+}, e -> { log.error("Failed to update ML model: " + modelId, e); });
+mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
+mlModelManager.removeAutoDeployModel(modelId);
+}));
}
listener.onResponse(new MLForwardResponse("ok", null));
break;
@@ -5,13 +5,16 @@

package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
@@ -37,7 +40,7 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
-@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
+@FieldDefaults(level = AccessLevel.PRIVATE)
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
TransportService transportService;
@@ -53,6 +56,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action

ModelAccessControlHelper modelAccessControlHelper;

private volatile boolean enableAutomaticDeployment;

@Inject
public TransportPredictionTaskAction(
TransportService transportService,
@@ -63,7 +68,8 @@ public TransportPredictionTaskAction(
Client client,
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
-ModelAccessControlHelper modelAccessControlHelper
+ModelAccessControlHelper modelAccessControlHelper,
+Settings settings
) {
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
this.mlPredictTaskRunner = mlPredictTaskRunner;
@@ -74,6 +80,10 @@ public TransportPredictionTaskAction(
this.xContentRegistry = xContentRegistry;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> enableAutomaticDeployment = it);
}

@Override
@@ -74,21 +74,21 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
continue;
}
if (functionName == FunctionName.REMOTE) {// remote model
-getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
+getEligibleNode(remoteModelEligibleNodeRoles, eligibleNodes, node);
} else { // local model
if (onlyRunOnMLNode) {
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node);
}
} else {
-getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
+getEligibleNode(localModelEligibleNodeRoles, eligibleNodes, node);
}
}
}
return eligibleNodes.toArray(new DiscoveryNode[0]);
}

-private void getEligibleNodes(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
+private void getEligibleNode(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node);
}
@@ -32,10 +32,13 @@
@Log4j2
public class MLModelCacheHelper {
private final Map<String, MLModelCache> modelCaches;

private final Map<String, MLModel> autoDeployModels;
private volatile Long maxRequestCount;

public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
this.modelCaches = new ConcurrentHashMap<>();
this.autoDeployModels = new ConcurrentHashMap<>();

maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it);
@@ -67,6 +70,25 @@ public synchronized void initModelState(
modelCaches.put(modelId, modelCache);
}

public synchronized void initModelStateLocal(
String modelId,
MLModelState state,
FunctionName functionName,
List<String> targetWorkerNodes
) {
log.debug("init local model deployment state for model {}, state: {}", modelId, state);
if (isModelRunningOnNode(modelId)) {
// model state initialized
return;
}
MLModelCache modelCache = new MLModelCache();
modelCache.setModelState(state);
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(false);
modelCaches.put(modelId, modelCache);
}

/**
* Set model state
*
@@ -358,6 +380,7 @@ public void removeModel(String modelId) {
modelCache.clear();
modelCaches.remove(modelId);
}
autoDeployModels.remove(modelId);
}

/**
Expand Down Expand Up @@ -590,4 +613,18 @@ private MLModelCache getOrCreateModelCache(String modelId) {
return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache());
}

public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model);
if (addedModel == model) {
log.info("Add model {} to auto deploy cache", modelId);
}
return addedModel;
}

public void removeAutoDeployModel(String modelId) {
MLModel removedModel = autoDeployModels.remove(modelId);
if (removedModel != null) {
log.info("Remove model {} from auto deploy cache", modelId);
}
}
}
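The autoDeployModels map above backs the "avoid over-deployment" part of the commit message: addModelToAutoDeployCache uses computeIfAbsent, so only the caller that actually inserted the entry gets back the exact instance it passed in, which means concurrent prediction requests against the same undeployed remote model can trigger at most one auto deploy, and removeAutoDeployModel / removeModel clear the entry once the deploy outcome has been recorded. A rough usage sketch follows; triggerRemoteDeploy(...) is a hypothetical placeholder and the MLModelCacheHelper import path is assumed.

import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.model.MLModelCacheHelper; // package assumed

/**
 * Sketch only: deduplicating concurrent auto-deploy attempts with the cache above.
 * triggerRemoteDeploy(...) is a hypothetical placeholder, not the PR's real dispatch path.
 */
class AutoDeployCacheSketch {

    void maybeAutoDeploy(MLModelCacheHelper cacheHelper, String modelId, MLModel model) {
        // Only the caller that wins the computeIfAbsent race gets its own instance back.
        MLModel cached = cacheHelper.addModelToAutoDeployCache(modelId, model);
        if (cached == model) {
            triggerRemoteDeploy(modelId); // the first request starts the deploy
        }
        // removeAutoDeployModel(modelId) or removeModel(modelId) later clears the entry,
        // e.g. after the deploy result has been written back to the model index.
    }

    private void triggerRemoteDeploy(String modelId) { /* placeholder */ }
}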