Skip to content

Commit

Permalink
add allow custom deployment plan setting; add deploy to all nodes fie…
Browse files Browse the repository at this point in the history
…ld in model index (#818)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Mar 22, 2023
1 parent aac0926 commit 4df13e3
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class CommonValue {

public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down Expand Up @@ -94,6 +94,9 @@ public class CommonValue {
+ MLModel.PLANNING_WORKER_NODES_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.DEPLOY_TO_ALL_NODES_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
16 changes: 15 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class MLModel implements ToXContentObject {
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

private String name;
private FunctionName algorithm;
Expand Down Expand Up @@ -85,6 +86,7 @@ public class MLModel implements ToXContentObject {
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;
@Builder(toBuilder = true)
public MLModel(String name,
FunctionName algorithm,
Expand All @@ -106,7 +108,8 @@ public MLModel(String name,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes) {
String[] planningWorkerNodes,
boolean deployToAllNodes) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
Expand All @@ -129,6 +132,7 @@ public MLModel(String name,
this.planningWorkerNodeCount = planningWorkerNodeCount;
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -165,6 +169,7 @@ public MLModel(StreamInput input) throws IOException{
planningWorkerNodeCount = input.readOptionalInt();
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
}
}

Expand Down Expand Up @@ -211,6 +216,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(planningWorkerNodeCount);
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
}

@Override
Expand Down Expand Up @@ -282,6 +288,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (planningWorkerNodes != null && planningWorkerNodes.length > 0) {
builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
}
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -312,6 +321,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
Integer planningWorkerNodeCount = null;
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -379,6 +389,9 @@ public static MLModel parse(XContentParser parser) throws IOException {
planningWorkerNodes.add(parser.text());
}
break;
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -422,6 +435,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
.planningWorkerNodeCount(planningWorkerNodeCount)
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
Expand All @@ -31,6 +32,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
Expand Down Expand Up @@ -75,6 +77,8 @@ public class TransportLoadModelAction extends HandledTransportAction<ActionReque
MLModelManager mlModelManager;
MLStats mlStats;

private volatile boolean allowCustomDeploymentPlan;

@Inject
public TransportLoadModelAction(
TransportService transportService,
Expand All @@ -88,7 +92,8 @@ public TransportLoadModelAction(
DiscoveryNodeHelper nodeFilter,
MLTaskDispatcher mlTaskDispatcher,
MLModelManager mlModelManager,
MLStats mlStats
MLStats mlStats,
Settings settings
) {
super(MLLoadModelAction.NAME, transportService, actionFilters, MLLoadModelRequest::new);
this.transportService = transportService;
Expand All @@ -102,13 +107,22 @@ public TransportLoadModelAction(
this.mlTaskDispatcher = mlTaskDispatcher;
this.mlModelManager = mlModelManager;
this.mlStats = mlStats;
allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, it -> allowCustomDeploymentPlan = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<LoadModelResponse> listener) {
MLLoadModelRequest deployModelRequest = MLLoadModelRequest.fromActionRequest(request);
String modelId = deployModelRequest.getModelId();
String[] targetNodeIds = deployModelRequest.getModelNodeIds();
boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0;
if (!allowCustomDeploymentPlan && !deployToAllNodes) {
throw new IllegalArgumentException("Don't allow custom deployment plan");
}

// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes();
Expand All @@ -121,7 +135,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo

List<DiscoveryNode> eligibleNodes = new ArrayList<>();
List<String> nodeIds = new ArrayList<>();
if (targetNodeIds != null && targetNodeIds.length > 0) {
if (!deployToAllNodes) {
for (String nodeId : targetNodeIds) {
if (allEligibleNodeIds.contains(nodeId)) {
eligibleNodes.add(nodeMapping.get(nodeId));
Expand Down Expand Up @@ -189,7 +203,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
localNodeId,
mlTask,
eligibleNodes,
algorithm
deployToAllNodes
)
);
} catch (Exception ex) {
Expand Down Expand Up @@ -226,7 +240,7 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
String localNodeId,
MLTask mlTask,
List<DiscoveryNode> eligibleNodes,
FunctionName algorithm
boolean deployToAllNodes
) {
LoadModelInput loadModelInput = new LoadModelInput(
modelId,
Expand Down Expand Up @@ -264,7 +278,9 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes
workerNodes,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
deployToAllNodes
),
ActionListener
.wrap(
Expand Down
50 changes: 43 additions & 7 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -174,6 +176,8 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
.fetchSource(
new String[] {
MLModel.MODEL_STATE_FIELD,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
MLModel.PLANNING_WORKER_NODES_FIELD,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
MLModel.LAST_UPDATED_TIME_FIELD,
MLModel.CURRENT_WORKER_NODE_COUNT_FIELD },
Expand All @@ -183,6 +187,7 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
client.search(searchRequest, ActionListener.wrap(res -> {
SearchHit[] hits = res.getHits().getHits();
Map<String, MLModelState> newModelStates = new HashMap<>();
Map<String, List<String>> newPlanningWorkerNodes = new HashMap<>();
for (SearchHit hit : hits) {
String modelId = hit.getId();
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
Expand All @@ -196,6 +201,24 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
: 0;
boolean deployToAllNodes = sourceAsMap.containsKey(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
? (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
: false;
List<String> planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD)
? (List<String>) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD)
: new ArrayList<>();
if (deployToAllNodes) {
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes();
planningWorkerNodeCount = eligibleNodes.length;
List<String> eligibleNodeIds = Arrays
.asList(eligibleNodes)
.stream()
.map(n -> n.getId())
.collect(Collectors.toList());
if (eligibleNodeIds.size() != planningWorkNodes.size() || !eligibleNodeIds.containsAll(planningWorkNodes)) {
newPlanningWorkerNodes.put(modelId, eligibleNodeIds);
}
}
MLModelState mlModelState = getNewModelState(
loadingModels,
modelWorkerNodes,
Expand All @@ -209,7 +232,7 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
newModelStates.put(modelId, mlModelState);
}
}
bulkUpdateModelState(modelWorkerNodes, newModelStates);
bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes);
}, e -> {
updateModelStateSemaphore.release();
log.error("Failed to search models", e);
Expand Down Expand Up @@ -270,16 +293,29 @@ private MLModelState getNewModelState(
return null;
}

private void bulkUpdateModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, MLModelState> newModelStates) {
if (newModelStates.size() > 0) {
private void bulkUpdateModelState(
Map<String, Set<String>> modelWorkerNodes,
Map<String, MLModelState> newModelStates,
Map<String, List<String>> newPlanningWorkNodes
) {
Set<String> updatedModelIds = new HashSet<>();
updatedModelIds.addAll(newModelStates.keySet());
updatedModelIds.addAll(newPlanningWorkNodes.keySet());

if (updatedModelIds.size() > 0) {
BulkRequest bulkUpdateRequest = new BulkRequest();
for (String modelId : newModelStates.keySet()) {
for (String modelId : updatedModelIds) {
UpdateRequest updateRequest = new UpdateRequest();
Instant now = Instant.now();
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder
.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name())
.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli());
if (newModelStates.containsKey(modelId)) {
builder.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name());
}
if (newPlanningWorkNodes.containsKey(modelId)) {
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkNodes.get(modelId));
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkNodes.get(modelId).size());
}
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli());
Set<String> workerNodes = modelWorkerNodes.get(modelId);
int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size();
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE,
MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX,
MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD,
MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES
MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES,
MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,6 @@ private MLCommonsSettings() {}

public static final Setting<String> ML_COMMONS_EXCLUDE_NODE_NAMES = Setting
.simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Boolean> ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting
.boolSetting("plugins.ml_commons.allow_custom_deployment_plan", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Loading

0 comments on commit 4df13e3

Please sign in to comment.