auto deployment for remote models #2206

Merged: 4 commits, Mar 19, 2024
@@ -52,4 +52,12 @@ public static FunctionName from(String value) {
public static boolean isDLModel(FunctionName functionName) {
return DL_MODELS.contains(functionName);
}

public static boolean needDeployFirst(FunctionName functionName) {
return DL_MODELS.contains(functionName) || functionName == REMOTE;
}

public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) {
return autoDeploymentEnabled && functionName == FunctionName.REMOTE;
}
}
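The two helpers added here gate auto deployment: needDeployFirst covers every function type that must be deployed before it can serve predictions (all DL models plus REMOTE), while isAutoDeployEnabled additionally requires the cluster-level auto-deploy setting and restricts the feature to remote models. A minimal caller sketch, assuming a hypothetical predict-time check; the variables modelIsDeployed and autoDeploySetting are illustrative and not part of this diff:

    // Hypothetical predict-time gate built on the new FunctionName helpers.
    FunctionName functionName = mlModel.getAlgorithm();
    boolean modelIsDeployed = false;   // placeholder for a real model-cache lookup
    boolean autoDeploySetting = true;  // value of ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE

    if (FunctionName.needDeployFirst(functionName)
        && !modelIsDeployed
        && FunctionName.isAutoDeployEnabled(autoDeploySetting, functionName)) {
        // transparently deploy the remote model before running the prediction
    }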
@@ -118,7 +118,6 @@ public class MLModel implements ToXContentObject {
private Integer totalChunks; // model chunk doc only
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

@@ -125,4 +125,4 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest

}

}
}
@@ -210,12 +210,12 @@ private void deployModel(
Set<String> allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet());

List<DiscoveryNode> eligibleNodes = new ArrayList<>();
List<String> nodeIds = new ArrayList<>();
List<String> eligibleNodeIds = new ArrayList<>();
if (!deployToAllNodes) {
for (String nodeId : targetNodeIds) {
if (allEligibleNodeIds.contains(nodeId)) {
eligibleNodes.add(nodeMapping.get(nodeId));
nodeIds.add(nodeId);
eligibleNodeIds.add(nodeId);
}
}
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
@@ -237,15 +237,15 @@ private void deployModel(
}
}
} else {
nodeIds.addAll(allEligibleNodeIds);
eligibleNodeIds.addAll(allEligibleNodeIds);
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
}
if (nodeIds.size() == 0) {
if (eligibleNodeIds.size() == 0) {
wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
return;
}

log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds));
log.info("Will deploy model on these nodes: {}", String.join(",", eligibleNodeIds));
String localNodeId = clusterService.localNode().getId();

FunctionName algorithm = mlModel.getAlgorithm();
@@ -261,18 +261,18 @@ private void deployModel(
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.state(MLTaskState.CREATED)
.workerNodes(nodeIds)
.workerNodes(eligibleNodeIds)
.build();
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
mlTaskManager.add(mlTask, nodeIds);
mlTaskManager.add(mlTask, eligibleNodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
}
try {
mlTaskManager.add(mlTask, nodeIds);
mlTaskManager.add(mlTask, eligibleNodeIds);
wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
@@ -222,11 +222,19 @@ private void deployModel(
try {
log.debug("start deploying model {}", modelId);
mlModelManager
.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> {
if (!coordinatingNodeId.equals(localNodeId)) {
mlTaskManager.remove(mlTask.getTaskId());
}
}));
.deployModel(
modelId,
modelContentHash,
functionName,
deployToAllNodes,
false,
mlTask,
ActionListener.runBefore(listener, () -> {
if (!coordinatingNodeId.equals(localNodeId)) {
mlTaskManager.remove(mlTask.getTaskId());
}
})
);
} catch (Exception e) {
logException("Failed to deploy model " + modelId, e, log);
listener.onFailure(e);
@@ -20,12 +20,14 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
@@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
updateFields.put(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, 0);
}
log.info("deploy model done with state: {}, model id: {}", modelState, modelId);
mlModelManager.updateModel(modelId, updateFields);
ActionListener updateModelListener = ActionListener.<UpdateResponse>wrap(response -> {
if (response.status() == RestStatus.OK) {
log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
} else {
log.error("Failed to update ML model {}, status: {}", modelId, response.status());
}
}, e -> { log.error("Failed to update ML model: " + modelId, e); });
mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
mlModelManager.removeAutoDeployModel(modelId);
}));
}
listener.onResponse(new MLForwardResponse("ok", null));
break;
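Because the update listener above is wrapped with ActionListener.runBefore, removeAutoDeployModel runs before either branch of updateModelListener is notified, so the auto-deploy cache entry is cleared whether the model-state update succeeds or fails. A small sketch of that semantics, with illustrative listener names and a hypothetical cleanup helper:

    // runBefore executes the runnable first, then delegates onResponse or onFailure to the wrapped listener.
    ActionListener<UpdateResponse> inner = ActionListener.wrap(
        response -> log.debug("update ok: {}", response.status()),
        e -> log.error("update failed", e)
    );
    ActionListener<UpdateResponse> outer = ActionListener.runBefore(inner, () -> cleanupAutoDeployEntry()); // cleanup runs on both paths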
@@ -5,13 +5,16 @@

package org.opensearch.ml.action.prediction;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
@@ -37,7 +40,7 @@
import lombok.extern.log4j.Log4j2;

@Log4j2
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@FieldDefaults(level = AccessLevel.PRIVATE)
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
TransportService transportService;
@@ -53,6 +56,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction<Action

ModelAccessControlHelper modelAccessControlHelper;

private volatile boolean enableAutomaticDeployment;

@Inject
public TransportPredictionTaskAction(
TransportService transportService,
@@ -63,7 +68,8 @@ public TransportPredictionTaskAction(
Client client,
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
Settings settings
) {
super(MLPredictionTaskAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
this.mlPredictTaskRunner = mlPredictTaskRunner;
@@ -74,6 +80,10 @@ public TransportPredictionTaskAction(
this.xContentRegistry = xContentRegistry;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
enableAutomaticDeployment = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> enableAutomaticDeployment = it);
}

@Override
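enableAutomaticDeployment is read once from node settings and then kept in sync through a cluster-settings update consumer, so the auto-deploy setting can be toggled at runtime without a restart. For reference, a dynamic boolean setting of this kind is typically declared as below; the exact key and default shown for ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE are assumptions, not copied from this diff:

    import org.opensearch.common.settings.Setting;

    // Assumed declaration in MLCommonsSettings; Property.Dynamic is what lets the
    // addSettingsUpdateConsumer registration above take effect at runtime.
    public static final Setting<Boolean> ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE = Setting
        .boolSetting("plugins.ml_commons.model_auto_deploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic);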
@@ -74,21 +74,21 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
continue;
}
if (functionName == FunctionName.REMOTE) {// remote model
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
getEligibleNode(remoteModelEligibleNodeRoles, eligibleNodes, node);
} else { // local model
if (onlyRunOnMLNode) {
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node);
}
} else {
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
getEligibleNode(localModelEligibleNodeRoles, eligibleNodes, node);
}
}
}
return eligibleNodes.toArray(new DiscoveryNode[0]);
}

private void getEligibleNodes(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
private void getEligibleNode(Set<String> allowedNodeRoles, Set<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node);
}
@@ -33,10 +33,13 @@
@Log4j2
public class MLModelCacheHelper {
private final Map<String, MLModelCache> modelCaches;

private final Map<String, MLModel> autoDeployModels;
private volatile Long maxRequestCount;

public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
this.modelCaches = new ConcurrentHashMap<>();
this.autoDeployModels = new ConcurrentHashMap<>();

maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it);
@@ -68,6 +71,25 @@ public synchronized void initModelState(
modelCaches.put(modelId, modelCache);
}

public synchronized void initModelStateLocal(
String modelId,
MLModelState state,
FunctionName functionName,
List<String> targetWorkerNodes
) {
log.debug("init local model deployment state for model {}, state: {}", modelId, state);
if (isModelRunningOnNode(modelId)) {
// model state initialized
return;
}
MLModelCache modelCache = new MLModelCache();
modelCache.setModelState(state);
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(false);
modelCaches.put(modelId, modelCache);
}
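initModelStateLocal parallels initModelState but pins deployToAllNodes to false, which fits the auto-deploy case where only the node handling the request tracks the remote model. A usage sketch under that assumption; the caller, state value, and node list below are illustrative rather than taken from this diff:

    // Illustrative: record DEPLOYING state for an auto-deployed remote model on the local node only.
    modelCacheHelper.initModelStateLocal(
        modelId,
        MLModelState.DEPLOYING,
        FunctionName.REMOTE,
        List.of(clusterService.localNode().getId())
    );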

/**
* Set model state
*
@@ -393,6 +415,7 @@ public void removeModel(String modelId) {
modelCache.clear();
modelCaches.remove(modelId);
}
autoDeployModels.remove(modelId);
}

/**
@@ -625,4 +648,18 @@ private MLModelCache getOrCreateModelCache(String modelId) {
return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache());
}

public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model);
if (addedModel == model) {
log.info("Add model {} to auto deploy cache", modelId);
}
return addedModel;
}

public void removeAutoDeployModel(String modelId) {
MLModel removedModel = autoDeployModels.remove(modelId);
if (removedModel != null) {
log.info("Remove model {} from auto deploy cache", modelId);
}
}
}
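addModelToAutoDeployCache relies on computeIfAbsent, so only the first caller's MLModel instance is stored and that same instance is returned to every caller; comparing the returned reference with the argument reveals which request is responsible for triggering the deployment. A sketch of that pattern, where the surrounding caller and the deploy call are illustrative and not part of this diff:

    // Illustrative caller: ensure a remote model is auto-deployed at most once at a time.
    MLModel cached = modelCacheHelper.addModelToAutoDeployCache(modelId, mlModel);
    if (cached == mlModel) {
        // this request inserted the entry, so it starts the deployment
        // deployRemoteModel(mlModel, ...);
    } else {
        // another in-flight request is already auto-deploying this model; skip the duplicate deploy
    }
    // removeAutoDeployModel(modelId) clears the entry once the deploy result is persisted (see the TransportForwardAction change above)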