Skip to content

Commit

Permalink
refresh model load state in sync up cron job (#704)
Browse files Browse the repository at this point in the history
* refresh model load state in sync up cron job

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* refactor code

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Jan 24, 2023
1 parent a788998 commit 1ba4de8
Show file tree
Hide file tree
Showing 17 changed files with 620 additions and 94 deletions.
11 changes: 10 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class CommonValue {

public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down Expand Up @@ -85,6 +85,12 @@ public class CommonValue {
+ MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD
+ "\" : {\"type\": \"long\"},\n"
+ " \""
+ MLModel.PLANNING_WORKER_NODE_COUNT_FIELD
+ "\" : {\"type\": \"integer\"},\n"
+ " \""
+ MLModel.CURRENT_WORKER_NODE_COUNT_FIELD
+ "\" : {\"type\": \"integer\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand All @@ -101,6 +107,9 @@ public class CommonValue {
+ MLModel.CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ MLModel.LAST_UPDATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ MLModel.LAST_UPLOADED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
Expand Down
60 changes: 59 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,16 @@ public class MLModel implements ToXContentObject {
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
public static final String LAST_UPLOADED_TIME_FIELD = "last_uploaded_time";
public static final String LAST_LOADED_TIME_FIELD = "last_loaded_time";
public static final String LAST_UNLOADED_TIME_FIELD = "last_unloaded_time";

public static final String MODEL_ID_FIELD = "model_id";
public static final String CHUNK_NUMBER_FIELD = "chunk_number";
public static final String TOTAL_CHUNKS_FIELD = "total_chunks";
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";

private String name;
private FunctionName algorithm;
Expand All @@ -66,6 +69,7 @@ public class MLModel implements ToXContentObject {
private String modelContentHash;
private MLModelConfig modelConfig;
private Instant createdTime;
private Instant lastUpdateTime;
private Instant lastUploadedTime;
private Instant lastLoadedTime;
private Instant lastUnloadedTime;
Expand All @@ -74,9 +78,30 @@ public class MLModel implements ToXContentObject {
private String modelId; // model chunk doc only
private Integer chunkNumber; // model chunk doc only
private Integer totalChunks; // model chunk doc only
private Integer planningWorkerNodeCount; // number of nodes the model is planned to be deployed to
private Integer currentWorkerNodeCount; // number of nodes the model is currently deployed to

@Builder(toBuilder = true)
public MLModel(String name, FunctionName algorithm, String version, String content, User user, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHash, MLModelConfig modelConfig, Instant createdTime, Instant lastUploadedTime, Instant lastLoadedTime, Instant lastUnloadedTime, String modelId, Integer chunkNumber, Integer totalChunks) {
public MLModel(String name,
FunctionName algorithm,
String version,
String content,
User user,
String description,
MLModelFormat modelFormat,
MLModelState modelState,
Long modelContentSizeInBytes,
String modelContentHash,
MLModelConfig modelConfig,
Instant createdTime,
Instant lastUpdateTime,
Instant lastUploadedTime,
Instant lastLoadedTime,
Instant lastUnloadedTime,
String modelId, Integer chunkNumber,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
Expand All @@ -89,12 +114,15 @@ public MLModel(String name, FunctionName algorithm, String version, String conte
this.modelContentHash = modelContentHash;
this.modelConfig = modelConfig;
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.lastUploadedTime = lastUploadedTime;
this.lastLoadedTime = lastLoadedTime;
this.lastUnloadedTime = lastUnloadedTime;
this.modelId = modelId;
this.chunkNumber = chunkNumber;
this.totalChunks = totalChunks;
this.planningWorkerNodeCount = planningWorkerNodeCount;
this.currentWorkerNodeCount = currentWorkerNodeCount;
}

public MLModel(StreamInput input) throws IOException{
Expand All @@ -121,12 +149,15 @@ public MLModel(StreamInput input) throws IOException{
modelConfig = new TextEmbeddingModelConfig(input);
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
lastUploadedTime = input.readOptionalInstant();
lastLoadedTime = input.readOptionalInstant();
lastUnloadedTime = input.readOptionalInstant();
modelId = input.readOptionalString();
chunkNumber = input.readOptionalInt();
totalChunks = input.readOptionalInt();
planningWorkerNodeCount = input.readOptionalInt();
currentWorkerNodeCount = input.readOptionalInt();
}
}

Expand Down Expand Up @@ -163,12 +194,15 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalInstant(lastUploadedTime);
out.writeOptionalInstant(lastLoadedTime);
out.writeOptionalInstant(lastUnloadedTime);
out.writeOptionalString(modelId);
out.writeOptionalInt(chunkNumber);
out.writeOptionalInt(totalChunks);
out.writeOptionalInt(planningWorkerNodeCount);
out.writeOptionalInt(currentWorkerNodeCount);
}

@Override
Expand Down Expand Up @@ -210,6 +244,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (createdTime != null) {
builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli());
}
if (lastUpdateTime != null) {
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
if (lastUploadedTime != null) {
builder.field(LAST_UPLOADED_TIME_FIELD, lastUploadedTime.toEpochMilli());
}
Expand All @@ -228,6 +265,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (totalChunks != null) {
builder.field(TOTAL_CHUNKS_FIELD, totalChunks);
}
if (planningWorkerNodeCount != null) {
builder.field(PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodeCount);
}
if (currentWorkerNodeCount != null) {
builder.field(CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkerNodeCount);
}
builder.endObject();
return builder;
}
Expand All @@ -248,12 +291,15 @@ public static MLModel parse(XContentParser parser) throws IOException {
String modelContentHash = null;
MLModelConfig modelConfig = null;
Instant createdTime = null;
Instant lastUpdateTime = null;
Instant lastUploadedTime = null;
Instant lastLoadedTime = null;
Instant lastUnloadedTime = null;
String modelId = null;
Integer chunkNumber = null;
Integer totalChunks = null;
Integer planningWorkerNodeCount = null;
Integer currentWorkerNodeCount = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -309,9 +355,18 @@ public static MLModel parse(XContentParser parser) throws IOException {
case MODEL_CONFIG_FIELD:
modelConfig = TextEmbeddingModelConfig.parse(parser);
break;
case PLANNING_WORKER_NODE_COUNT_FIELD:
planningWorkerNodeCount = parser.intValue();
break;
case CURRENT_WORKER_NODE_COUNT_FIELD:
currentWorkerNodeCount = parser.intValue();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
break;
case LAST_UPLOADED_TIME_FIELD:
lastUploadedTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -339,12 +394,15 @@ public static MLModel parse(XContentParser parser) throws IOException {
.modelContentHash(modelContentHash)
.modelConfig(modelConfig)
.createdTime(createdTime)
.lastUpdateTime(lastUpdateTime)
.lastUploadedTime(lastUploadedTime)
.lastLoadedTime(lastLoadedTime)
.lastUnloadedTime(lastUnloadedTime)
.modelId(modelId)
.chunkNumber(chunkNumber)
.totalChunks(totalChunks)
.planningWorkerNodeCount(planningWorkerNodeCount)
.currentWorkerNodeCount(currentWorkerNodeCount)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,31 @@
import org.opensearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Map;

@Log4j2
@Getter
public class MLSyncUpNodeResponse extends BaseNodeResponse {

private String modelStatus;
private String[] loadedModelIds;
private String[] runningLoadModelTaskIds;
private String[] runningLoadModelIds; // ids of models that have a load-model task running
private String[] runningLoadModelTaskIds; // ids of the load-model tasks that are running

public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] loadedModelIds, String[] runningLoadModelTaskIds) {
/**
 * Creates the sync-up response of a single node from already-collected values.
 * All arguments are stored as given; no copying or validation is performed.
 *
 * @param node                    node handed to the base node response
 * @param modelStatus             model status value carried by this response
 * @param loadedModelIds          ids of loaded models
 * @param runningLoadModelIds     ids of models that have a load-model task running
 * @param runningLoadModelTaskIds ids of the load-model tasks that are running
 */
public MLSyncUpNodeResponse(
    DiscoveryNode node,
    String modelStatus,
    String[] loadedModelIds,
    String[] runningLoadModelIds,
    String[] runningLoadModelTaskIds
) {
    super(node);
    // Task-related arrays first, then the model-level state; the assignments are independent.
    this.runningLoadModelTaskIds = runningLoadModelTaskIds;
    this.runningLoadModelIds = runningLoadModelIds;
    this.loadedModelIds = loadedModelIds;
    this.modelStatus = modelStatus;
}

/**
 * Reads a node response from a stream.
 * After the base class has consumed its part, the four optional fields are read
 * in the order status, loaded model ids, running load model ids, running load
 * task ids. This order must stay in step with what {@code writeTo} writes.
 *
 * @param in stream to read the response from
 * @throws IOException if reading from the stream fails
 */
public MLSyncUpNodeResponse(StreamInput in) throws IOException {
super(in);
this.modelStatus = in.readOptionalString();
this.loadedModelIds = in.readOptionalStringArray();
// Sits between loadedModelIds and runningLoadModelTaskIds on the wire.
this.runningLoadModelIds = in.readOptionalStringArray();
this.runningLoadModelTaskIds = in.readOptionalStringArray();
}

Expand All @@ -45,6 +50,7 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(modelStatus);
out.writeOptionalStringArray(loadedModelIds);
out.writeOptionalStringArray(runningLoadModelIds);
out.writeOptionalStringArray(runningLoadModelTaskIds);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.*;
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
Expand All @@ -23,6 +25,7 @@ public class MLSyncUpNodeResponseTest {
private final String modelStatus = "modelStatus";
private final String[] loadedModelIds = {"loadedModelIds"};
private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"};
private final String[] runningLoadModelIds = {"modelid1"};
@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
Expand All @@ -37,7 +40,7 @@ public void setUp() throws Exception {

@Test
public void testSerializationDeserialization() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput());
Expand All @@ -50,7 +53,7 @@ public void testSerializationDeserialization() throws IOException {

@Test
public void testReadProfile() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,18 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw

if (workNodes == null || workNodes.size() == 0) {
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
taskState = MLTaskState.FAILED;
currentWorkerNodeCount = 0;
} else {
syncModelWorkerNodes(modelId);
}
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLTask.STATE_FIELD, taskState);
if (mlTaskCache.hasError()) {
currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize() - mlTaskCache.getErrors().size();
builder.put(MLTask.ERROR_FIELD, toJsonString(mlTaskCache.getErrors()));
}
mlTaskManager.updateMLTask(taskId, builder.build(), TASK_SEMAPHORE_TIMEOUT, true);
Expand All @@ -125,7 +128,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
.updateModel(
modelId,
ImmutableMap
.of(MLModel.MODEL_STATE_FIELD, modelState, MLModel.LAST_LOADED_TIME_FIELD, Instant.now().toEpochMilli())
.of(
MLModel.MODEL_STATE_FIELD,
modelState,
MLModel.LAST_LOADED_TIME_FIELD,
Instant.now().toEpochMilli(),
MLModel.CURRENT_WORKER_NODE_COUNT_FIELD,
currentWorkerNodeCount
)
);
}
listener.onResponse(new MLForwardResponse("ok", null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
mlModelManager
.updateModel(
modelId,
ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING),
ImmutableMap
.of(MLModel.MODEL_STATE_FIELD, MLModelState.LOADING, MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size()),
ActionListener
.wrap(
r -> client.execute(MLLoadModelOnNodeAction.INSTANCE, loadModelRequest, actionListener),
Expand Down
Loading

0 comments on commit 1ba4de8

Please sign in to comment.