From 4f87254382b41e93aa4da59a176e619176263a02 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 18 Mar 2024 22:28:24 -0700 Subject: [PATCH] auto deployment for remote models (#2206) * auto deployment for remote models Signed-off-by: Xun Zhang * add auto deploy feature flag Signed-off-by: Xun Zhang * add eligible node check and avoid over-deployment Signed-off-by: Xun Zhang * dispatch local deploy Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../opensearch/ml/common/FunctionName.java | 8 ++ .../org/opensearch/ml/common/MLModel.java | 1 - .../deploy/MLDeployModelRequest.java | 2 +- .../deploy/TransportDeployModelAction.java | 16 ++-- .../TransportDeployModelOnNodeAction.java | 18 ++-- .../forward/TransportForwardAction.java | 13 ++- .../TransportPredictionTaskAction.java | 14 +++- .../ml/cluster/DiscoveryNodeHelper.java | 6 +- .../ml/model/MLModelCacheHelper.java | 37 ++++++++ .../opensearch/ml/model/MLModelManager.java | 58 ++++++++++++- .../ml/plugin/MachineLearningPlugin.java | 6 +- .../ml/settings/MLCommonsSettings.java | 3 + .../ml/task/MLPredictTaskRunner.java | 84 ++++++++++++++++++- ...TransportDeployModelOnNodeActionTests.java | 12 +-- .../TransportPredictionTaskActionTests.java | 22 +++-- .../ml/model/MLModelManagerTests.java | 18 ++-- .../ml/task/MLPredictTaskRunnerTests.java | 12 ++- 17 files changed, 276 insertions(+), 54 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 2c02b3e13d..a2c900f6cc 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -52,4 +52,12 @@ public static FunctionName from(String value) { public static boolean isDLModel(FunctionName functionName) { return DL_MODELS.contains(functionName); } + + public static boolean needDeployFirst(FunctionName functionName) { + return DL_MODELS.contains(functionName) || functionName == REMOTE; + } + + public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) { + return autoDeploymentEnabled && functionName == FunctionName.REMOTE; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index cec2805891..bd341d036c 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -116,7 +116,6 @@ public class MLModel implements ToXContentObject { private Integer totalChunks; // model chunk doc only private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes private Integer currentWorkerNodeCount; // model is deployed to how many nodes - private String[] planningWorkerNodes; // plan to deploy model to these nodes private boolean deployToAllNodes; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java index b0ad113d95..7279dbed57 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java @@ -116,4 +116,4 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest } -} +} \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 8d1c4f706e..bd54c0ef60 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -207,12 +207,12 @@ private void deployModel( Set allEligibleNodeIds = Arrays.stream(allEligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toSet()); List eligibleNodes = new ArrayList<>(); - List nodeIds = new ArrayList<>(); + List eligibleNodeIds = new ArrayList<>(); if (!deployToAllNodes) { for (String nodeId : targetNodeIds) { if (allEligibleNodeIds.contains(nodeId)) { eligibleNodes.add(nodeMapping.get(nodeId)); - nodeIds.add(nodeId); + eligibleNodeIds.add(nodeId); } } String[] workerNodes = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); @@ -234,15 +234,15 @@ private void deployModel( } } } else { - nodeIds.addAll(allEligibleNodeIds); + eligibleNodeIds.addAll(allEligibleNodeIds); eligibleNodes.addAll(Arrays.asList(allEligibleNodes)); } - if (nodeIds.size() == 0) { + if (eligibleNodeIds.size() == 0) { wrappedListener.onFailure(new IllegalArgumentException("no eligible node found")); return; } - log.info("Will deploy model on these nodes: {}", String.join(",", nodeIds)); + log.info("Will deploy model on these nodes: {}", String.join(",", eligibleNodeIds)); String localNodeId = clusterService.localNode().getId(); FunctionName algorithm = mlModel.getAlgorithm(); @@ -258,18 +258,18 @@ private void deployModel( .createTime(Instant.now()) .lastUpdateTime(Instant.now()) .state(MLTaskState.CREATED) - .workerNodes(nodeIds) + .workerNodes(eligibleNodeIds) .build(); mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> { String taskId = response.getId(); mlTask.setTaskId(taskId); if (algorithm == FunctionName.REMOTE) { - mlTaskManager.add(mlTask, nodeIds); + mlTaskManager.add(mlTask, eligibleNodeIds); deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener); return; } try { - mlTaskManager.add(mlTask, nodeIds); + mlTaskManager.add(mlTask, eligibleNodeIds); wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name())); threadPool .executor(DEPLOY_THREAD_POOL) diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index bf8c81756b..495ea771f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -222,11 +222,19 @@ private void deployModel( try { log.debug("start deploying model {}", modelId); mlModelManager - .deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> { - if (!coordinatingNodeId.equals(localNodeId)) { - mlTaskManager.remove(mlTask.getTaskId()); - } - })); + .deployModel( + modelId, + modelContentHash, + functionName, + deployToAllNodes, + false, + mlTask, + ActionListener.runBefore(listener, () -> { + if (!coordinatingNodeId.equals(localNodeId)) { + mlTaskManager.remove(mlTask.getTaskId()); + } + }) + ); } catch (Exception e) { logException("Failed to deploy model " + modelId, e, log); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 01dfa690bb..276ce1774e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -20,12 +20,14 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListenerwrap(response -> { + if (response.status() == RestStatus.OK) { + log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId); + } else { + log.error("Failed to update ML model {}, status: {}", modelId, response.status()); + } + }, e -> { log.error("Failed to update ML model: " + modelId, e); }); + mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> { + mlModelManager.removeAutoDeployModel(modelId); + })); } listener.onResponse(new MLForwardResponse("ok", null)); break; diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 4bf66564d9..4a3f77a838 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -5,6 +5,8 @@ package org.opensearch.ml.action.prediction; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; + import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -12,6 +14,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -37,7 +40,7 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@FieldDefaults(level = AccessLevel.PRIVATE) public class TransportPredictionTaskAction extends HandledTransportAction { MLTaskRunner mlPredictTaskRunner; TransportService transportService; @@ -53,6 +56,8 @@ public class TransportPredictionTaskAction extends HandledTransportAction enableAutomaticDeployment = it); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java index 5476e3f520..5b06236d54 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java @@ -74,21 +74,21 @@ public DiscoveryNode[] getEligibleNodes(FunctionName functionName) { continue; } if (functionName == FunctionName.REMOTE) {// remote model - getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node); + getEligibleNode(remoteModelEligibleNodeRoles, eligibleNodes, node); } else { // local model if (onlyRunOnMLNode) { if (MLNodeUtils.isMLNode(node)) { eligibleNodes.add(node); } } else { - getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node); + getEligibleNode(localModelEligibleNodeRoles, eligibleNodes, node); } } } return eligibleNodes.toArray(new DiscoveryNode[0]); } - private void getEligibleNodes(Set allowedNodeRoles, Set eligibleNodes, DiscoveryNode node) { + private void getEligibleNode(Set allowedNodeRoles, Set eligibleNodes, DiscoveryNode node) { if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) { eligibleNodes.add(node); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 99ccc9cce1..5cf7c6e09f 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -32,10 +32,13 @@ @Log4j2 public class MLModelCacheHelper { private final Map modelCaches; + + private final Map autoDeployModels; private volatile Long maxRequestCount; public MLModelCacheHelper(ClusterService clusterService, Settings settings) { this.modelCaches = new ConcurrentHashMap<>(); + this.autoDeployModels = new ConcurrentHashMap<>(); maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it); @@ -67,6 +70,25 @@ public synchronized void initModelState( modelCaches.put(modelId, modelCache); } + public synchronized void initModelStateLocal( + String modelId, + MLModelState state, + FunctionName functionName, + List targetWorkerNodes + ) { + log.debug("init local model deployment state for model {}, state: {}", modelId, state); + if (isModelRunningOnNode(modelId)) { + // model state initialized + return; + } + MLModelCache modelCache = new MLModelCache(); + modelCache.setModelState(state); + modelCache.setFunctionName(functionName); + modelCache.setTargetWorkerNodes(targetWorkerNodes); + modelCache.setDeployToAllNodes(false); + modelCaches.put(modelId, modelCache); + } + /** * Set model state * @@ -358,6 +380,7 @@ public void removeModel(String modelId) { modelCache.clear(); modelCaches.remove(modelId); } + autoDeployModels.remove(modelId); } /** @@ -590,4 +613,18 @@ private MLModelCache getOrCreateModelCache(String modelId) { return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache()); } + public MLModel addModelToAutoDeployCache(String modelId, MLModel model) { + MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model); + if (addedModel == model) { + log.info("Add model {} to auto deploy cache", modelId); + } + return addedModel; + } + + public void removeAutoDeployModel(String modelId) { + MLModel removedModel = autoDeployModels.remove(modelId); + if (removedModel != null) { + log.info("Remove model {} from auto deploy cache", modelId); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 2d20167f54..a2c3ce2645 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -50,6 +50,7 @@ import java.nio.file.Path; import java.security.PrivilegedActionException; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; @@ -64,6 +65,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.apache.commons.lang3.BooleanUtils; import org.apache.logging.log4j.util.Strings; @@ -78,6 +80,7 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.TokenBucket; @@ -934,6 +937,7 @@ public void deployModel( String modelContentHash, FunctionName functionName, boolean deployToAllNodes, + boolean autoDeployModel, MLTask mlTask, ActionListener listener ) { @@ -943,7 +947,7 @@ public void deployModel( mlStats.createModelCounterStatIfAbsent(modelId, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { - if (workerNodes != null && workerNodes.size() > 0) { + if (!autoDeployModel && workerNodes != null && workerNodes.size() > 0) { log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes); modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); @@ -956,10 +960,20 @@ public void deployModel( return; } int eligibleNodeCount = workerNodes.size(); - modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); + if (!autoDeployModel) { + modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); + } else { + modelCacheHelper.initModelStateLocal(modelId, MLModelState.DEPLOYING, functionName, workerNodes); + } + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> { + context.restore(); + modelCacheHelper.removeAutoDeployModel(modelId); + }); + if (!autoDeployModel) { + checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); + } this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -1049,6 +1063,23 @@ public void deployModel( } } + public void deployRemoteModelToLocal(String modelId, MLModel mlModel, ActionListener listener) { + if (modelCacheHelper.isModelDeployed(modelId)) { + listener.onResponse("Success"); + return; + } + modelCacheHelper + .initModelState(modelId, MLModelState.DEPLOYING, FunctionName.REMOTE, new ArrayList<>(), mlModel.isDeployToAllNodes()); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); + deployRemoteOrBuiltInModel(mlModel, 1, wrappedListener); + } catch (Exception e) { + log.error("Failed to deploy model to local node" + modelId, e); + listener.onFailure(e); + } + } + private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener wrappedListener) { String modelId = mlModel.getModelId(); setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); @@ -1821,4 +1852,23 @@ public boolean isModelRunningOnNode(String modelId) { return modelCacheHelper.isModelRunningOnNode(modelId); } + public boolean isModelDeployed(String modelId) { + return modelCacheHelper.isModelDeployed(modelId); + } + + public boolean isNodeEligible(String nodeId, FunctionName functionName) { + Set allEligibleNodeIds = Arrays + .stream(nodeHelper.getEligibleNodes(functionName)) + .map(DiscoveryNode::getId) + .collect(Collectors.toSet()); + return allEligibleNodeIds.contains(nodeId); + } + + public MLModel addModelToAutoDeployCache(String modelId, MLModel model) { + return modelCacheHelper.addModelToAutoDeployCache(modelId, model); + } + + public void removeAutoDeployModel(String modelId) { + modelCacheHelper.removeAutoDeployModel(modelId); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 8eb48340f6..e3fc93bfb4 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -523,7 +523,8 @@ public Collection createComponents( xContentRegistry, mlModelManager, nodeHelper, - mlEngine + mlEngine, + settings ); mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner( threadPool, @@ -897,7 +898,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, - MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED + MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED, + MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 33c5b4b554..f96eb37e02 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -82,6 +82,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting .boolSetting("plugins.ml_commons.allow_custom_deployment_plan", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE = Setting + .boolSetting("plugins.ml_commons.model_auto_deploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE = Setting .boolSetting("plugins.ml_commons.model_auto_redeploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 10904740ec..a6d82e979f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -12,8 +12,10 @@ import static org.opensearch.ml.permission.AccessController.getUserContext; import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; import java.time.Instant; +import java.util.Arrays; import java.util.UUID; import org.opensearch.OpenSearchException; @@ -25,6 +27,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; @@ -45,6 +48,8 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; @@ -76,6 +81,7 @@ public class MLPredictTaskRunner extends MLTaskRunner autoDeploymentEnabled = it); } @Override @@ -134,7 +145,35 @@ public void dispatchTask( }, e -> { listener.onFailure(e); }); String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); if (workerNodes == null || workerNodes.length == 0) { - if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) { + if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> { + String[] planningWorkerNodes = model.getPlanningWorkerNodes(); + MLModel modelBeingAutoDeployed = mlModelManager.addModelToAutoDeployCache(modelId, model); + if (modelBeingAutoDeployed == model) { + log.info("Automatically deploy model {}", modelId); + MLDeployModelRequest deployModelRequest = new MLDeployModelRequest( + modelId, + planningWorkerNodes, + false, + true, + false + ); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(r -> { + log.info("Auto deployment action triggered for model {}", modelId); + }, e -> { log.error("Auto deployment action failed for model " + modelId, e); })); + } + if (planningWorkerNodes == null || planningWorkerNodes.length == 0) { + planningWorkerNodes = nodeHelper.getEligibleNodeIds(functionName); + } + mlTaskDispatcher.dispatchPredictTask(planningWorkerNodes, actionListener); + }, e -> { + log.error("Failed to get model " + modelId, e); + listener.onFailure(e); + }), context::restore)); + } + return; + } else if (FunctionName.needDeployFirst(functionName)) { listener .onFailure( new IllegalArgumentException( @@ -145,6 +184,8 @@ public void dispatchTask( } else { workerNodes = nodeHelper.getEligibleNodeIds(functionName); } + } else { + mlModelManager.removeAutoDeployModel(modelId); } mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener); } catch (Exception e) { @@ -218,7 +259,42 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe mlTask.setState(MLTaskState.RUNNING); mlTaskManager.add(mlTask); - FunctionName algorithm = mlInput.getAlgorithm(); + FunctionName functionName = mlInput.getFunctionName(); + Predictable predictor = mlModelManager.getPredictor(modelId); + boolean modelReady = predictor != null && predictor.isModelReady(); + if (!modelReady && FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { + log.info("Auto deploy model {} to local node", modelId); + Instant now = Instant.now(); + MLTask mlDeployTask = MLTask + .builder() + .taskId(UUID.randomUUID().toString()) + .functionName(functionName) + .async(false) + .taskType(MLTaskType.DEPLOY_MODEL) + .createTime(now) + .lastUpdateTime(now) + .state(MLTaskState.RUNNING) + .workerNodes(Arrays.asList(clusterService.localNode().getId())) + .build(); + mlModelManager.deployModel(modelId, null, functionName, false, true, mlDeployTask, ActionListener.wrap(s -> { + runPredict(modelId, mlTask, mlInput, functionName, internalListener); + }, e -> { + log.error("Failed to auto deploy model " + modelId, e); + internalListener.onFailure(e); + })); + return; + } + + runPredict(modelId, mlTask, mlInput, functionName, internalListener); + } + + private void runPredict( + String modelId, + MLTask mlTask, + MLInput mlInput, + FunctionName algorithm, + ActionListener internalListener + ) { // run predict if (modelId != null) { Predictable predictor = mlModelManager.getPredictor(modelId); @@ -241,7 +317,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe handlePredictFailure(mlTask, internalListener, e, false, modelId); return; } - } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + } else if (FunctionName.needDeployFirst(algorithm)) { throw new IllegalArgumentException("Model not ready to be used: " + modelId); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java index 7a10c10e27..83852cc68f 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeActionTests.java @@ -207,7 +207,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(5); listener.onResponse("successful"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); doAnswer(invocation -> { ActionListenerResponseHandler handler = invocation.getArgument(3); @@ -313,7 +313,7 @@ public void testNodeOperation_FailToSendForwardRequest() { ActionListener listener = invocation.getArgument(4); listener.onResponse("ok"); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); doAnswer(invocation -> { TransportResponseHandler handler = invocation.getArgument(3); handler.handleException(new TransportException("error")); @@ -331,7 +331,7 @@ public void testNodeOperation_Exception() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new RuntimeException("Something went wrong")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -340,7 +340,9 @@ public void testNodeOperation_Exception() { @Ignore public void testNodeOperation_DeployModelRuntimeException() { - doThrow(new RuntimeException("error")).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any()); + doThrow(new RuntimeException("error")) + .when(mlModelManager) + .deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); @@ -353,7 +355,7 @@ public void testNodeOperation_MLLimitExceededException() { ActionListener listener = invocation.getArgument(4); listener.onFailure(new MLLimitExceededException("Limit exceeded exception")); return null; - }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any()); + }).when(mlModelManager).deployModel(any(), any(), any(), any(Boolean.class), any(), any(), any()); final MLDeployModelNodesRequest nodesRequest = prepareRequest(localNode.getId()); final MLDeployModelNodeRequest request = action.newNodeRequest(nodesRequest); final MLDeployModelNodeResponse response = action.nodeOperation(request); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index c714237695..461f2b834b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -11,9 +11,12 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import org.junit.Before; import org.junit.Rule; @@ -25,6 +28,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; @@ -117,7 +121,15 @@ public void setup() { mlPredictionTaskRequest = MLPredictionTaskRequest.builder().modelId("test_id").mlInput(mlInput).user(user).build(); - Settings settings = Settings.builder().build(); + Settings settings = Settings.builder().put(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.getKey(), true).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE))); + + threadContext = new ThreadContext(settings); + when(clusterService.getSettings()).thenReturn(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + transportPredictionTaskAction = spy( new TransportPredictionTaskAction( transportService, @@ -128,14 +140,10 @@ public void setup() { client, xContentRegistry, mlModelManager, - modelAccessControlHelper + modelAccessControlHelper, + settings ) ); - - threadContext = new ThreadContext(settings); - when(clusterService.getSettings()).thenReturn(settings); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); } public void testPrediction_default_exception() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 60338e4ccd..760dc44417 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -576,7 +576,7 @@ public void testDeployModel_FailedToGetModel() { mock_threadpool(threadPool, taskExecutorService); mock_client_get_failure(client); mock_client_ThreadContext(client, threadPool, threadContext); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -617,7 +617,7 @@ public void testDeployModel_NullGetModelResponse() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); mock_client_get_NullResponse(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -658,7 +658,7 @@ public void testDeployModel_GetModelResponse_NotExist() { when(modelCacheHelper.getLocalDeployedModels()).thenReturn(new String[] {}); mock_threadpool(threadPool, taskExecutorService); mock_client_get_NotExist(client); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -703,7 +703,7 @@ public void testDeployModel_GetModelResponse_wrong_hash_value() { setUpMock_GetModel(model); setUpMock_GetModel(modelChunk0); setUpMock_GetModel(modelChunk0); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -753,7 +753,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { setUpMock_GetModelChunks(model); // setUpMock_GetModel(modelChunk0); // setUpMock_GetModel(modelChunk1); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); @@ -769,7 +769,7 @@ public void testDeployModel_GetModelResponse_FailedToDeploy() { public void testDeployModel_ModelAlreadyDeployed() { when(modelCacheHelper.isModelDeployed(modelId)).thenReturn(true); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor response = ArgumentCaptor.forClass(String.class); verify(listener).onResponse(response.capture()); assertEquals("successful", response.getValue()); @@ -784,7 +784,7 @@ public void testDeployModel_ExceedMaxDeployedModel() { when(modelCacheHelper.getDeployedModels()).thenReturn(models); when(modelCacheHelper.getLocalDeployedModels()).thenReturn(models); ActionListener listener = mock(ActionListener.class); - modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, FunctionName.TEXT_EMBEDDING, true, false, mlTask, listener); ArgumentCaptor failure = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(failure.capture()); assertEquals("Exceed max local model per node limit", failure.getValue().getMessage()); @@ -819,7 +819,7 @@ public void testDeployModel_ThreadPoolException() { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_FAILURE_COUNT)); } @@ -978,7 +978,7 @@ private void testDeployModel_FailedToRetrieveModelChunks(boolean lastChunk) { ActionListener listener = mock(ActionListener.class); FunctionName functionName = FunctionName.TEXT_EMBEDDING; - modelManager.deployModel(modelId, modelContentHashValue, functionName, true, mlTask, listener); + modelManager.deployModel(modelId, modelContentHashValue, functionName, true, false, mlTask, listener); verify(modelCacheHelper).removeModel(eq(modelId)); verify(mlStats).createCounterStatIfAbsent(eq(functionName), eq(ActionName.DEPLOY), eq(MLActionLevelStat.ML_ACTION_REQUEST_COUNT)); verify(mlStats).getStat(eq(MLNodeLevelStat.ML_REQUEST_COUNT)); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 13de526978..ac33c6c76d 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -9,9 +9,12 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.*; import static org.mockito.Mockito.spy; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; import java.io.IOException; import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashSet; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -27,6 +30,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; @@ -145,6 +149,10 @@ public void setup() throws IOException { stats.put(MLNodeLevelStat.ML_REQUEST_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_FAILURE_COUNT, new MLStat<>(false, new CounterSupplier())); stats.put(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT, new MLStat<>(false, new CounterSupplier())); + + Settings settings = Settings.builder().put(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.getKey(), true).build(); + ClusterSettings clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE))); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); this.mlStats = new MLStats(stats); mlInputDatasetHandler = spy(new MLInputDatasetHandler(client)); taskRunner = spy( @@ -160,7 +168,8 @@ public void setup() throws IOException { xContentRegistry(), mlModelManager, nodeHelper, - mlEngine + mlEngine, + settings ) ); @@ -188,7 +197,6 @@ public void setup() throws IOException { requestWithQuery = MLPredictionTaskRequest.builder().modelId("111").mlInput(mlInputWithQuery).build(); when(client.threadPool()).thenReturn(threadPool); - Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool);