diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index f5fb4ee486..40c9b89e0b 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -30,7 +30,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER @@ -85,6 +85,12 @@ public class CommonValue { + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD + "\" : {\"type\": \"long\"},\n" + " \"" + + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + MLModel.MODEL_CONFIG_FIELD + "\" : {\"properties\":{\"" + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" @@ -101,6 +107,9 @@ public class CommonValue { + MLModel.CREATED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" + + MLModel.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + MLModel.LAST_UPLOADED_TIME_FIELD + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + " \"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index e5e6b6f83d..c506ceb42c 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -44,6 +44,7 @@ public class MLModel implements ToXContentObject { public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; public static final String LAST_UPLOADED_TIME_FIELD = "last_uploaded_time"; public static final String LAST_LOADED_TIME_FIELD = "last_loaded_time"; public static final String LAST_UNLOADED_TIME_FIELD = "last_unloaded_time"; @@ -51,6 +52,8 @@ public class MLModel implements ToXContentObject { public static final String MODEL_ID_FIELD = "model_id"; public static final String CHUNK_NUMBER_FIELD = "chunk_number"; public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; + public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count"; + public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count"; private String name; private FunctionName algorithm; @@ -66,6 +69,7 @@ public class MLModel implements ToXContentObject { private String modelContentHash; private MLModelConfig modelConfig; private Instant createdTime; + private Instant lastUpdateTime; private Instant lastUploadedTime; private Instant lastLoadedTime; private Instant lastUnloadedTime; @@ -74,9 +78,30 @@ public class MLModel implements ToXContentObject { private String modelId; // model chunk doc only private Integer chunkNumber; // model chunk doc only private Integer totalChunks; // model chunk doc only + private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes + private Integer currentWorkerNodeCount; // model is deployed to how many nodes @Builder(toBuilder = true) - public MLModel(String name, FunctionName algorithm, String version, String content, User user, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHash, MLModelConfig modelConfig, Instant createdTime, Instant lastUploadedTime, Instant lastLoadedTime, Instant lastUnloadedTime, String modelId, Integer chunkNumber, Integer totalChunks) { + public MLModel(String name, + FunctionName algorithm, + String version, + String content, + User user, + String description, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHash, + MLModelConfig modelConfig, + Instant createdTime, + Instant lastUpdateTime, + Instant lastUploadedTime, + Instant lastLoadedTime, + Instant lastUnloadedTime, + String modelId, Integer chunkNumber, + Integer totalChunks, + Integer planningWorkerNodeCount, + Integer currentWorkerNodeCount) { this.name = name; this.algorithm = algorithm; this.version = version; @@ -89,12 +114,15 @@ public MLModel(String name, FunctionName algorithm, String version, String conte this.modelContentHash = modelContentHash; this.modelConfig = modelConfig; this.createdTime = createdTime; + this.lastUpdateTime = lastUpdateTime; this.lastUploadedTime = lastUploadedTime; this.lastLoadedTime = lastLoadedTime; this.lastUnloadedTime = lastUnloadedTime; this.modelId = modelId; this.chunkNumber = chunkNumber; this.totalChunks = totalChunks; + this.planningWorkerNodeCount = planningWorkerNodeCount; + this.currentWorkerNodeCount = currentWorkerNodeCount; } public MLModel(StreamInput input) throws IOException{ @@ -121,12 +149,15 @@ public MLModel(StreamInput input) throws IOException{ modelConfig = new TextEmbeddingModelConfig(input); } createdTime = input.readOptionalInstant(); + lastUpdateTime = input.readOptionalInstant(); lastUploadedTime = input.readOptionalInstant(); lastLoadedTime = input.readOptionalInstant(); lastUnloadedTime = input.readOptionalInstant(); modelId = input.readOptionalString(); chunkNumber = input.readOptionalInt(); totalChunks = input.readOptionalInt(); + planningWorkerNodeCount = input.readOptionalInt(); + currentWorkerNodeCount = input.readOptionalInt(); } } @@ -163,12 +194,15 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdateTime); out.writeOptionalInstant(lastUploadedTime); out.writeOptionalInstant(lastLoadedTime); out.writeOptionalInstant(lastUnloadedTime); out.writeOptionalString(modelId); out.writeOptionalInt(chunkNumber); out.writeOptionalInt(totalChunks); + out.writeOptionalInt(planningWorkerNodeCount); + out.writeOptionalInt(currentWorkerNodeCount); } @Override @@ -210,6 +244,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); } + if (lastUpdateTime != null) { + builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } if (lastUploadedTime != null) { builder.field(LAST_UPLOADED_TIME_FIELD, lastUploadedTime.toEpochMilli()); } @@ -228,6 +265,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (totalChunks != null) { builder.field(TOTAL_CHUNKS_FIELD, totalChunks); } + if (planningWorkerNodeCount != null) { + builder.field(PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodeCount); + } + if (currentWorkerNodeCount != null) { + builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount); + } builder.endObject(); return builder; } @@ -248,12 +291,15 @@ public static MLModel parse(XContentParser parser) throws IOException { String modelContentHash = null; MLModelConfig modelConfig = null; Instant createdTime = null; + Instant lastUpdateTime = null; Instant lastUploadedTime = null; Instant lastLoadedTime = null; Instant lastUnloadedTime = null; String modelId = null; Integer chunkNumber = null; Integer totalChunks = null; + Integer planningWorkerNodeCount = null; + Integer currentWorkerNodeCount = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -309,9 +355,18 @@ public static MLModel parse(XContentParser parser) throws IOException { case MODEL_CONFIG_FIELD: modelConfig = TextEmbeddingModelConfig.parse(parser); break; + case PLANNING_WORKER_NODE_COUNT_FIELD: + planningWorkerNodeCount = parser.intValue(); + break; + case CURRENT_WORKER_NODE_COUNT_FIELD: + currentWorkerNodeCount = parser.intValue(); + break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; + case LAST_UPDATED_TIME_FIELD: + lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); + break; case LAST_UPLOADED_TIME_FIELD: lastUploadedTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -339,12 +394,15 @@ public static MLModel parse(XContentParser parser) throws IOException { .modelContentHash(modelContentHash) .modelConfig(modelConfig) .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) .lastUploadedTime(lastUploadedTime) .lastLoadedTime(lastLoadedTime) .lastUnloadedTime(lastUnloadedTime) .modelId(modelId) .chunkNumber(chunkNumber) .totalChunks(totalChunks) + .planningWorkerNodeCount(planningWorkerNodeCount) + .currentWorkerNodeCount(currentWorkerNodeCount) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java index c4abaeab87..c7fd6edcf9 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java @@ -13,6 +13,7 @@ import org.opensearch.common.io.stream.StreamOutput; import java.io.IOException; +import java.util.Map; @Log4j2 @Getter @@ -20,12 +21,15 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse { private String modelStatus; private String[] loadedModelIds; - private String[] runningLoadModelTaskIds; + private String[] runningLoadModelIds; // model ids which have loading model task running + private String[] runningLoadModelTaskIds; // load model task ids which is running - public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] loadedModelIds, String[] runningLoadModelTaskIds) { + public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] loadedModelIds, String[] runningLoadModelIds, + String[] runningLoadModelTaskIds) { super(node); this.modelStatus = modelStatus; this.loadedModelIds = loadedModelIds; + this.runningLoadModelIds = runningLoadModelIds; this.runningLoadModelTaskIds = runningLoadModelTaskIds; } @@ -33,6 +37,7 @@ public MLSyncUpNodeResponse(StreamInput in) throws IOException { super(in); this.modelStatus = in.readOptionalString(); this.loadedModelIds = in.readOptionalStringArray(); + this.runningLoadModelIds = in.readOptionalStringArray(); this.runningLoadModelTaskIds = in.readOptionalStringArray(); } @@ -45,6 +50,7 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(modelStatus); out.writeOptionalStringArray(loadedModelIds); + out.writeOptionalStringArray(runningLoadModelIds); out.writeOptionalStringArray(runningLoadModelTaskIds); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java index 6bbd8ad4b8..ff153b952a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java @@ -12,6 +12,8 @@ import java.io.IOException; import java.net.InetAddress; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import static org.junit.Assert.*; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; @@ -23,6 +25,7 @@ public class MLSyncUpNodeResponseTest { private final String modelStatus = "modelStatus"; private final String[] loadedModelIds = {"loadedModelIds"}; private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"}; + private final String[] runningLoadModelIds = {"modelid1"}; @Before public void setUp() throws Exception { localNode = new DiscoveryNode( @@ -37,7 +40,7 @@ public void setUp() throws Exception { @Test public void testSerializationDeserialization() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput()); @@ -50,7 +53,7 @@ public void testSerializationDeserialization() throws IOException { @Test public void testReadProfile() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 571e3bb0a4..7e2c46c82c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -100,15 +100,18 @@ protected void doExecute(Task task, ActionRequest request, ActionListener builder = ImmutableMap.builder(); builder.put(MLTask.STATE_FIELD, taskState); if (mlTaskCache.hasError()) { + currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize() - mlTaskCache.getErrors().size(); builder.put(MLTask.ERROR_FIELD, toJsonString(mlTaskCache.getErrors())); } mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true); @@ -125,7 +128,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener), diff --git a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java index bffac18b9b..d2e56184bf 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java @@ -27,11 +27,9 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest; @@ -148,9 +146,12 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM String[] loadedModelIds = null; String[] runningLoadModelTaskIds = null; + String[] runningLoadModelIds = null; if (syncUpInput.isGetLoadedModels()) { loadedModelIds = mlModelManager.getLocalLoadedModels(); - runningLoadModelTaskIds = mlTaskManager.getLocalRunningLoadModelTasks(); + List localRunningLoadModel = mlTaskManager.getLocalRunningLoadModelTasks(); + runningLoadModelTaskIds = localRunningLoadModel.get(0); + runningLoadModelIds = localRunningLoadModel.get(1); } if (syncUpInput.isClearRoutingTable()) { @@ -162,18 +163,14 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM mlModelManager.syncModelWorkerNodes(modelRoutingTable); } - if (syncUpInput.isSyncRunningLoadModelTasks()) { - mlTaskManager.syncRunningLoadModelTasks(runningLoadModelTasks); - } - - cleanUpLocalCache(); + cleanUpLocalCache(runningLoadModelTasks); cleanUpLocalCacheFiles(); - return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", loadedModelIds, runningLoadModelTaskIds); + return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); } @VisibleForTesting - void cleanUpLocalCache() { + void cleanUpLocalCache(Map> runningLoadModelTasks) { String[] allTaskIds = mlTaskManager.getAllTaskIds(); if (allTaskIds == null) { return; @@ -185,6 +182,12 @@ void cleanUpLocalCache() { Instant now = Instant.now(); if (now.isAfter(lastUpdateTime.plusSeconds(mlTaskTimeout))) { log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType()); + if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL + && mlTask.getState() == MLTaskState.CREATED + && runningLoadModelTasks != null + && runningLoadModelTasks.containsKey(taskId)) { + continue; + } mlTaskManager .updateMLTask( taskId, @@ -193,31 +196,6 @@ void cleanUpLocalCache() { 10_000, true ); - - if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL) { - String modelId = mlTask.getModelId(); - String[] workerNodes = mlModelManager.getWorkerNodes(modelId); - MLModelState modelState; - if (workerNodes == null || workerNodes.length == 0) { - modelState = MLModelState.LOAD_FAILED; - } else if (mlTask.getWorkerNodes().size() > workerNodes.length) { - modelState = MLModelState.PARTIALLY_LOADED; - } else { - modelState = MLModelState.LOADED; - if (mlTask.getWorkerNodes().size() < workerNodes.length) { - log - .warn( - "Model loaded on more nodes than target worker nodes. taskId:{}, modelId: {}, workerNodes: {}, targetWorkerNodes: {}", - taskId, - modelId, - Arrays.toString(workerNodes), - Arrays.toString(mlTask.getWorkerNodes().toArray(new String[0])) - ); - } - } - log.info("Reset model state as {} for model {}", modelState, modelId); - mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, modelState)); - } } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java index 702be4188a..2c114c5272 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelMetaCreate.java @@ -48,6 +48,7 @@ public void createModelMeta(MLCreateModelMetaInput mlCreateModelMetaInput, Actio FunctionName functionName = mlCreateModelMetaInput.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + Instant now = Instant.now(); MLModel mlModelMeta = MLModel .builder() .name(modelName) @@ -60,7 +61,8 @@ public void createModelMeta(MLCreateModelMetaInput mlCreateModelMetaInput, Actio .totalChunks(mlCreateModelMetaInput.getTotalChunks()) .modelContentHash(mlCreateModelMetaInput.getModelContentHashValue()) .modelContentSizeInBytes(mlCreateModelMetaInput.getModelContentSizeInBytes()) - .createdTime(Instant.now()) + .createdTime(now) + .lastUpdateTime(now) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); indexRequest diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 75dafd94a5..be2f42195f 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -16,6 +16,7 @@ import org.opensearch.common.component.LifecycleListener; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -28,6 +29,7 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan private ThreadPool threadPool; private Scheduler.Cancellable syncModelRoutingCron; private DiscoveryNodeHelper nodeHelper; + private final MLIndicesHandler mlIndicesHandler; private volatile Integer jobInterval; @@ -36,13 +38,15 @@ public MLCommonsClusterManagerEventListener( Client client, Settings settings, ThreadPool threadPool, - DiscoveryNodeHelper nodeHelper + DiscoveryNodeHelper nodeHelper, + MLIndicesHandler mlIndicesHandler ) { this.clusterService = clusterService; this.client = client; this.threadPool = threadPool; this.clusterService.addListener(this); this.nodeHelper = nodeHelper; + this.mlIndicesHandler = mlIndicesHandler; this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> { @@ -62,7 +66,11 @@ public void onClusterManager() { private void startSyncModelRoutingCron() { if (jobInterval > 0) { syncModelRoutingCron = threadPool - .scheduleWithFixedDelay(new MLSyncUpCron(client, nodeHelper), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL); + .scheduleWithFixedDelay( + new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler), + TimeValue.timeValueSeconds(jobInterval), + GENERAL_THREAD_POOL + ); } else { log.debug("Stop ML syncup job as its interval is: {}", jobInterval); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 8767ab0db3..ab362b5291 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -5,31 +5,58 @@ package org.opensearch.ml.cluster; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + +import java.time.Instant; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Semaphore; import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.indices.MLIndicesHandler; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; @Log4j2 public class MLSyncUpCron implements Runnable { + public static final int LOAD_MODEL_TASK_GRACE_TIME_IN_MS = 20_000; private Client client; + private ClusterService clusterService; private DiscoveryNodeHelper nodeHelper; + private MLIndicesHandler mlIndicesHandler; + @VisibleForTesting + Semaphore updateModelStateSemaphore; - public MLSyncUpCron(Client client, DiscoveryNodeHelper nodeHelper) { + public MLSyncUpCron(Client client, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler) { this.client = client; + this.clusterService = clusterService; this.nodeHelper = nodeHelper; + this.mlIndicesHandler = mlIndicesHandler; + this.updateModelStateSemaphore = new Semaphore(1); } @Override @@ -45,6 +72,8 @@ public void run() { Map> modelWorkerNodes = new HashMap<>(); // key is task id, value is set of worker node ids Map> runningLoadModelTasks = new HashMap<>(); + // key is model id, value is set of worker node ids + Map> loadingModels = new HashMap<>(); for (MLSyncUpNodeResponse response : responses) { String nodeId = response.getNode().getId(); String[] loadedModelIds = response.getLoadedModelIds(); @@ -54,6 +83,14 @@ public void run() { workerNodes.add(nodeId); } } + String[] runningModelIds = response.getRunningLoadModelIds(); + if (runningModelIds != null && runningModelIds.length > 0) { + for (String modelId : runningModelIds) { + Set workerNodes = loadingModels.computeIfAbsent(modelId, it -> new HashSet<>()); + workerNodes.add(nodeId); + } + } + String[] runningLoadModelTaskIds = response.getRunningLoadModelTaskIds(); if (runningLoadModelTaskIds != null && runningLoadModelTaskIds.length > 0) { for (String taskId : runningLoadModelTaskIds) { @@ -63,7 +100,8 @@ public void run() { } } for (Map.Entry> entry : modelWorkerNodes.entrySet()) { - log.debug("will sync model worker nodes for model: {}: {}", entry.getKey(), entry.getValue().toArray(new String[0])); + String modelId = entry.getKey(); + log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0])); } for (Map.Entry> entry : runningLoadModelTasks.entrySet()) { log.debug("will sync running task: {}: {}", entry.getKey(), entry.getValue().toArray(new String[0])); @@ -91,6 +129,171 @@ public void run() { ex -> { log.error("Failed to sync model routing", ex); } ) ); + + // refresh model status + if (clusterService.state().getRoutingTable().hasIndex(ML_MODEL_INDEX)) { + mlIndicesHandler + .initModelIndexIfAbsent( + ActionListener + .wrap( + res -> { refreshModelState(modelWorkerNodes, loadingModels); }, + e -> { log.error("Failed to init model index", e); } + ) + ); + } }, e -> { log.error("Failed to sync model routing", e); })); } + + @VisibleForTesting + void refreshModelState(Map> modelWorkerNodes, Map> loadingModels) { + if (!updateModelStateSemaphore.tryAcquire()) { + return; + } + try { + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); + BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); + queryBuilder + .filter( + new TermsQueryBuilder( + MLModel.MODEL_STATE_FIELD, + Arrays + .asList( + MLModelState.LOADING.name(), + MLModelState.PARTIALLY_LOADED.name(), + MLModelState.LOADED.name(), + MLModelState.LOAD_FAILED.name() + ) + ) + ); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(queryBuilder); + sourceBuilder.size(10_000); + sourceBuilder + .fetchSource( + new String[] { + MLModel.MODEL_STATE_FIELD, + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, + MLModel.LAST_UPDATED_TIME_FIELD, + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD }, + null + ); + searchRequest.source(sourceBuilder); + client.search(searchRequest, ActionListener.wrap(res -> { + SearchHit[] hits = res.getHits().getHits(); + Map newModelStates = new HashMap<>(); + for (SearchHit hit : hits) { + String modelId = hit.getId(); + Map sourceAsMap = hit.getSourceAsMap(); + MLModelState state = MLModelState.from((String) sourceAsMap.get(MLModel.MODEL_STATE_FIELD)); + Long lastUpdateTime = sourceAsMap.containsKey(MLModel.LAST_UPDATED_TIME_FIELD) + ? (Long) sourceAsMap.get(MLModel.LAST_UPDATED_TIME_FIELD) + : null; + int planningWorkerNodeCount = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) + ? (int) sourceAsMap.get(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD) + : 0; + int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) + ? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD) + : 0; + MLModelState mlModelState = getNewModelState( + loadingModels, + modelWorkerNodes, + modelId, + state, + lastUpdateTime, + planningWorkerNodeCount, + currentWorkerNodeCountInIndex + ); + if (mlModelState != null) { + newModelStates.put(modelId, mlModelState); + } + } + bulkUpdateModelState(modelWorkerNodes, newModelStates); + }, e -> { + updateModelStateSemaphore.release(); + log.error("Failed to search models", e); + })); + } catch (Exception e) { + updateModelStateSemaphore.release(); + log.error("Failed to refresh model state", e); + } + } + + private MLModelState getNewModelState( + Map> loadingModels, + Map> modelWorkerNodes, + + String modelId, + MLModelState state, + Long lastUpdateTime, + int planningWorkerNodeCount, + int currentWorkerNodeCountInIndex + ) { + Set loadTaskNodes = loadingModels.get(modelId); + if (loadTaskNodes != null && loadTaskNodes.size() > 0 && state != MLModelState.LOADING) { + // no + return MLModelState.LOADING; + } + int currentWorkerNodeCount = modelWorkerNodes.containsKey(modelId) ? modelWorkerNodes.get(modelId).size() : 0; + if (currentWorkerNodeCount == 0 + && state != MLModelState.LOAD_FAILED + && !(state == MLModelState.LOADING + && lastUpdateTime != null + && lastUpdateTime + LOAD_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli())) { + // If model not deployed to any node and no node is loading the model, then set model state as LOAD_FAILED + return MLModelState.LOAD_FAILED; + } + if (currentWorkerNodeCount > 0) { + if (currentWorkerNodeCount < planningWorkerNodeCount + && (state != MLModelState.PARTIALLY_LOADED || currentWorkerNodeCountInIndex != currentWorkerNodeCount)) { + // If model deployed to some node/nodes, but not deployed to all nodes planned by user, + // then set model state as PARTIALLY_LOADED. + return MLModelState.PARTIALLY_LOADED; + } else if (planningWorkerNodeCount > 0 && currentWorkerNodeCount >= planningWorkerNodeCount && state != MLModelState.LOADED) { + if (currentWorkerNodeCount > planningWorkerNodeCount) { + // This case should not happen that model loaded to more nodes than planned. So log this as warning if + // it happens. + log + .warn( + "Model {} loaded on more nodes [{}] than planing worker node[{}]", + modelId, + currentWorkerNodeCount, + planningWorkerNodeCount + ); + } + + // If model deployed to all nodes planned by user, then set model state as LOADED. + return MLModelState.LOADED; + } + } + return null; + } + + private void bulkUpdateModelState(Map> modelWorkerNodes, Map newModelStates) { + if (newModelStates.size() > 0) { + BulkRequest bulkUpdateRequest = new BulkRequest(); + for (String modelId : newModelStates.keySet()) { + UpdateRequest updateRequest = new UpdateRequest(); + Instant now = Instant.now(); + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder + .put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name()) + .put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli()); + Set workerNodes = modelWorkerNodes.get(modelId); + int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size(); + builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount); + updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build()); + bulkUpdateRequest.add(updateRequest); + } + log.info("Refresh model state: {}", newModelStates); + client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> { + updateModelStateSemaphore.release(); + log.debug("Refresh model state successfully"); + }, e -> { + updateModelStateSemaphore.release(); + log.error("Failed to bulk update model state", e); + })); + } else { + updateModelStateSemaphore.release(); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index a988d91c74..88cdb94755 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -217,6 +217,7 @@ private void uploadModelFromUrl(MLUploadInput uploadInput, MLTask mlTask) { mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = uploadInput.getModelName(); String version = uploadInput.getVersion(); + Instant now = Instant.now(); mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { MLModel mlModelMeta = MLModel .builder() @@ -227,7 +228,8 @@ private void uploadModelFromUrl(MLUploadInput uploadInput, MLTask mlTask) { .modelFormat(uploadInput.getModelFormat()) .modelState(MLModelState.UPLOADING) .modelConfig(uploadInput.getModelConfig()) - .createdTime(Instant.now()) + .createdTime(now) + .lastUpdateTime(now) .build(); IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX); indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS)); @@ -284,6 +286,7 @@ private void uploadModel( File file = new File(name); byte[] bytes = Files.toByteArray(file); int chunkNum = Integer.parseInt(file.getName()); + Instant now = Instant.now(); MLModel mlModel = MLModel .builder() .modelId(modelId) @@ -294,7 +297,8 @@ private void uploadModel( .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) - .createdTime(Instant.now()) + .createdTime(now) + .lastUpdateTime(now) .build(); IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX); String chunkId = getModelChunkId(modelId, chunkNum); @@ -611,11 +615,14 @@ public void updateModel(String modelId, Map updatedFields, Actio listener.onFailure(new IllegalArgumentException("Updated fields is null or empty")); return; } + Map newUpdatedFields = new HashMap<>(); + newUpdatedFields.putAll(updatedFields); + newUpdatedFields.put(MLModel.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli()); UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); - updateRequest.doc(updatedFields); + updateRequest.doc(newUpdatedFields); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD) - && MODEL_DONE_STATES.contains(updatedFields.get(MLModel.MODEL_STATE_FIELD))) { + if (newUpdatedFields.containsKey(MLModel.MODEL_STATE_FIELD) + && MODEL_DONE_STATES.contains(newUpdatedFields.get(MLModel.MODEL_STATE_FIELD))) { updateRequest.retryOnConflict(3); } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index ed8b1b4c6b..2c44a7aa0d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -348,7 +348,8 @@ public Collection createComponents( client, settings, threadPool, - nodeHelper + nodeHelper, + mlIndicesHandler ); return ImmutableList diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 26cf36a2bc..3bae6ea92a 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -27,7 +27,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", - 10, + 3, 0, 86400, Setting.Property.NodeScope, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index 65c2f7a08a..d61d1aa677 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -12,11 +12,9 @@ import static org.opensearch.ml.utils.MLExceptionUtils.logException; import java.time.Instant; -import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -390,38 +388,17 @@ public boolean containsModel(String modelId) { return false; } - public String[] getLocalRunningLoadModelTasks() { + public List getLocalRunningLoadModelTasks() { List runningLoadModelTaskIds = new ArrayList<>(); + List runningLoadModelIds = new ArrayList<>(); for (Map.Entry entry : taskCaches.entrySet()) { MLTask mlTask = entry.getValue().getMlTask(); if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL && mlTask.getState() != MLTaskState.CREATED) { runningLoadModelTaskIds.add(entry.getKey()); + runningLoadModelIds.add(mlTask.getModelId()); } } - return runningLoadModelTaskIds.toArray(new String[0]); + return Arrays.asList(runningLoadModelTaskIds.toArray(new String[0]), runningLoadModelIds.toArray(new String[0])); } - public void syncRunningLoadModelTasks(Map> runningLoadModelTasks) { - Instant ttlEndTime = Instant.now().minus(10, ChronoUnit.MINUTES); - Set staleTasks = new HashSet<>(); - - boolean noRunningTask = runningLoadModelTasks == null || runningLoadModelTasks.size() == 0; - for (Map.Entry entry : taskCaches.entrySet()) { - String taskId = entry.getKey(); - MLTask mlTask = entry.getValue().getMlTask(); - boolean exceedTTL = mlTask.getLastUpdateTime().isBefore(ttlEndTime); - if (exceedTTL - && mlTask.getTaskType() == MLTaskType.LOAD_MODEL - && mlTask.getState() == MLTaskState.CREATED - && (noRunningTask || !runningLoadModelTasks.containsKey(taskId))) { - staleTasks.add(entry.getKey()); - } - } - if (staleTasks.size() > 0) { - log.debug("remove stale load tasks : {}", Arrays.toString(staleTasks.toArray(new String[0]))); - for (String taskId : staleTasks) { - remove(taskId); - } - } - } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java index 1b99273abb..4b66941a27 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -154,7 +155,8 @@ public void testDoExecute_LoadModelDone_NoError() { public void testDoExecute_LoadModelDone_Error_NullTaskWorkerNodes() { when(mlTaskManager.getWorkNodes(anyString())).thenReturn(null); - MLTaskCache mlTaskCache = MLTaskCache.builder().mlTask(createMlTask(MLTaskType.UPLOAD_MODEL)).build(); + List workerNodes = Arrays.asList(nodeId1, nodeId2); + MLTaskCache mlTaskCache = MLTaskCache.builder().mlTask(createMlTask(MLTaskType.UPLOAD_MODEL)).workerNodes(workerNodes).build(); mlTaskCache.addError(nodeId1, error); doReturn(mlTaskCache).when(mlTaskManager).getMLTaskCache(anyString()); when(mlModelManager.getWorkerNodes(anyString())).thenReturn(new String[] { nodeId1, nodeId2 }); diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java index 60e72c88cf..217eb5c049 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -53,7 +53,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.model.MLModelState; @@ -112,6 +111,8 @@ public class TransportSyncUpOnNodeActionTests extends OpenSearchTestCase { private TransportSyncUpOnNodeAction action; + private Map> runningLoadModelTasks; + @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); @@ -130,6 +131,10 @@ public void setup() throws IOException { xContentRegistry, mlEngine ); + runningLoadModelTasks = new HashMap<>(); + runningLoadModelTasks.put("model1", ImmutableSet.of("node1")); + when(mlTaskManager.getLocalRunningLoadModelTasks()) + .thenReturn(Arrays.asList(new String[] { "load_task_id1" }, new String[] { "model_id1" })); } public void testConstructor() { @@ -158,8 +163,15 @@ public void testNewNodeResponse() throws IOException { Version.CURRENT ); String[] loadedModelIds = new String[] { "123" }; + String[] runningLoadModelIds = new String[] { "model1" }; String[] runningLoadModelTaskIds = new String[] { "1" }; - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(mlNode1, "LOADED", loadedModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( + mlNode1, + "LOADED", + loadedModelIds, + runningLoadModelIds, + runningLoadModelTaskIds + ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); final MLSyncUpNodeResponse response1 = action.newNodeResponse(output.bytes().streamInput()); @@ -239,13 +251,13 @@ public void testNodeOperation_RemovedWorkerNodes() throws IOException { public void testCleanUpLocalCache_NoTasks() { when(mlTaskManager.getAllTaskIds()).thenReturn(null); - action.cleanUpLocalCache(); + action.cleanUpLocalCache(runningLoadModelTasks); verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); } public void testCleanUpLocalCache_EmptyTasks() { when(mlTaskManager.getAllTaskIds()).thenReturn(new String[] {}); - action.cleanUpLocalCache(); + action.cleanUpLocalCache(runningLoadModelTasks); verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); } @@ -255,7 +267,7 @@ public void testCleanUpLocalCache_NotExpiredMLTask() { MLTask mlTask = MLTask.builder().lastUpdateTime(Instant.now()).build(); MLTaskCache taskCache = MLTaskCache.builder().mlTask(mlTask).build(); when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); - action.cleanUpLocalCache(); + action.cleanUpLocalCache(runningLoadModelTasks); verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); } @@ -265,7 +277,7 @@ public void testCleanUpLocalCache_ExpiredMLTask_Upload() { MLTask mlTask = MLTask.builder().taskType(MLTaskType.UPLOAD_MODEL).lastUpdateTime(Instant.now().minusSeconds(86400)).build(); MLTaskCache taskCache = MLTaskCache.builder().mlTask(mlTask).build(); when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); - action.cleanUpLocalCache(); + action.cleanUpLocalCache(runningLoadModelTasks); verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); verify(mlModelManager, never()).updateModel(anyString(), any()); } @@ -303,11 +315,10 @@ private void testCleanUpLocalCache_ExpiredMLTask_LoadStatus(MLModelState modelSt when(mlModelManager.getWorkerNodes(modelId)).thenReturn(new String[] { "node1" }); } when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); - action.cleanUpLocalCache(); + action.cleanUpLocalCache(runningLoadModelTasks); verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); - verify(mlModelManager, times(1)).updateModel(eq(modelId), argumentCaptor.capture()); - assertEquals(modelState, argumentCaptor.getValue().get(MLModel.MODEL_STATE_FIELD)); + verify(mlModelManager, never()).updateModel(eq(modelId), argumentCaptor.capture()); } private MLSyncUpInput prepareRequest() { diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 95363577b7..a057ba045a 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -9,27 +9,54 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import java.io.IOException; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchTestCase; import com.google.common.collect.ImmutableSet; @@ -39,6 +66,8 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { @Mock private Client client; @Mock + private ClusterService clusterService; + @Mock private DiscoveryNodeHelper nodeHelper; private DiscoveryNode mlNode1; @@ -53,7 +82,7 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - syncUpCron = new MLSyncUpCron(client, nodeHelper); + syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, null); } public void testRun() { @@ -83,13 +112,194 @@ public void testRun_Failure() { verify(client, times(1)).execute(eq(MLSyncUpAction.INSTANCE), any(), any()); } + public void testRefreshModelState_NoSemaphore() throws InterruptedException { + syncUpCron.updateModelStateSemaphore.acquire(); + syncUpCron.refreshModelState(null, null); + verify(client, never()).search(any(), any()); + syncUpCron.updateModelStateSemaphore.release(); + } + + public void testRefreshModelState_SearchException() { + doThrow(new RuntimeException("test exception")).when(client).search(any(), any()); + syncUpCron.refreshModelState(null, null); + verify(client, times(1)).search(any(), any()); + assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); + syncUpCron.updateModelStateSemaphore.release(); + } + + public void testRefreshModelState_SearchFailed() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("search error")); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(null, null); + verify(client, times(1)).search(any(), any()); + assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); + syncUpCron.updateModelStateSemaphore.release(); + } + + public void testRefreshModelState_EmptySearchResponse() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + SearchHits hits = new SearchHits(new SearchHit[0], null, Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections( + hits, + InternalAggregations.EMPTY, + null, + true, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(new HashMap<>(), new HashMap<>()); + verify(client, times(1)).search(any(), any()); + verify(client, never()).bulk(any(), any()); + assertTrue(syncUpCron.updateModelStateSemaphore.tryAcquire()); + syncUpCron.updateModelStateSemaphore.release(); + } + + public void testRefreshModelState_ResetAsLoadFailed() { + Map> modelWorkerNodes = new HashMap<>(); + Map> loadingModels = new HashMap<>(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.LOADED, 2, null, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); + assertEquals(1, bulkRequest.numberOfActions()); + assertEquals(1, bulkRequest.requests().size()); + UpdateRequest updateRequest = (UpdateRequest) bulkRequest.requests().get(0); + String updateContent = updateRequest.toString(); + assertTrue(updateContent.contains("\"model_state\":\"LOAD_FAILED\"")); + assertTrue(updateContent.contains("\"current_worker_node_count\":0")); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + } + + public void testRefreshModelState_ResetAsPartiallyLoaded() { + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); + Map> loadingModels = new HashMap<>(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.LOADED, 2, 0, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); + assertEquals(1, bulkRequest.numberOfActions()); + assertEquals(1, bulkRequest.requests().size()); + UpdateRequest updateRequest = (UpdateRequest) bulkRequest.requests().get(0); + String updateContent = updateRequest.toString(); + assertTrue(updateContent.contains("\"model_state\":\"PARTIALLY_LOADED\"")); + assertTrue(updateContent.contains("\"current_worker_node_count\":1")); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + } + + public void testRefreshModelState_ResetCurrentWorkerNodeCountForPartiallyLoaded() { + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); + Map> loadingModels = new HashMap<>(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener + .onResponse(createSearchModelResponse("modelId", MLModelState.PARTIALLY_LOADED, 3, 2, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); + assertEquals(1, bulkRequest.numberOfActions()); + assertEquals(1, bulkRequest.requests().size()); + UpdateRequest updateRequest = (UpdateRequest) bulkRequest.requests().get(0); + String updateContent = updateRequest.toString(); + assertTrue(updateContent.contains("\"model_state\":\"PARTIALLY_LOADED\"")); + assertTrue(updateContent.contains("\"current_worker_node_count\":1")); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + } + + public void testRefreshModelState_ResetAsLoading() { + Map> modelWorkerNodes = new HashMap<>(); + modelWorkerNodes.put("modelId", ImmutableSet.of("node1")); + Map> loadingModels = new HashMap<>(); + loadingModels.put("modelId", ImmutableSet.of("node2")); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.LOAD_FAILED, 2, 0, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(bulkRequestCaptor.capture(), any()); + BulkRequest bulkRequest = bulkRequestCaptor.getValue(); + assertEquals(1, bulkRequest.numberOfActions()); + assertEquals(1, bulkRequest.requests().size()); + UpdateRequest updateRequest = (UpdateRequest) bulkRequest.requests().get(0); + String updateContent = updateRequest.toString(); + assertTrue(updateContent.contains("\"model_state\":\"LOADING\"")); + assertTrue(updateContent.contains("\"current_worker_node_count\":1")); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + } + + public void testRefreshModelState_NotResetState_LoadingModelTaskRunning() { + Map> modelWorkerNodes = new HashMap<>(); + Map> loadingModels = new HashMap<>(); + loadingModels.put("modelId", ImmutableSet.of("node2")); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.LOADING, 2, null, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + verify(client, never()).bulk(any(), any()); + } + + public void testRefreshModelState_NotResetState_LoadingInGraceTime() { + Map> modelWorkerNodes = new HashMap<>(); + Map> loadingModels = new HashMap<>(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(createSearchModelResponse("modelId", MLModelState.LOADING, 2, null, Instant.now().toEpochMilli())); + return null; + }).when(client).search(any(), any()); + syncUpCron.refreshModelState(modelWorkerNodes, loadingModels); + verify(client, times(1)).search(any(), any()); + verify(client, never()).bulk(any(), any()); + } + private void mockSyncUp_GatherRunningTasks() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); List nodeResponses = new ArrayList<>(); String[] loadedModelIds = new String[] { randomAlphaOfLength(10) }; + String[] runningLoadModelIds = new String[] { randomAlphaOfLength(10) }; String[] runningLoadModelTaskIds = new String[] { randomAlphaOfLength(10) }; - nodeResponses.add(new MLSyncUpNodeResponse(mlNode1, "ok", loadedModelIds, runningLoadModelTaskIds)); + nodeResponses.add(new MLSyncUpNodeResponse(mlNode1, "ok", loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds)); MLSyncUpNodesResponse response = new MLSyncUpNodesResponse(ClusterName.DEFAULT, nodeResponses, Arrays.asList()); listener.onResponse(response); return null; @@ -103,4 +313,44 @@ private void mockSyncUp_GatherRunningTasks_Failure() { return null; }).when(client).execute(eq(MLSyncUpAction.INSTANCE), any(), any()); } + + private SearchResponse createSearchModelResponse( + String modelId, + MLModelState state, + Integer planningWorkerNodeCount, + Integer currentWorkerNodeCount, + Long lastUpdateTime + ) throws IOException { + XContentBuilder content = TestHelper.builder(); + content.startObject(); + content.field(MLModel.MODEL_STATE_FIELD, state); + content.field(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodeCount); + if (currentWorkerNodeCount != null) { + content.field(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount); + } + content.field(MLModel.LAST_UPDATED_TIME_FIELD, lastUpdateTime); + content.endObject(); + + SearchHit[] hits = new SearchHit[1]; + hits[0] = new SearchHit(0, modelId, null, null).sourceRef(BytesReference.bytes(content)); + + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } }