Skip to content

Commit

Permalink
the dev of [FEATURE]Auto reload model when cluster rebooted/node rejo…
Browse files Browse the repository at this point in the history
…in (#711)

* [wjunshen] #N/A feat: fix after the latest rebase

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: fix after rebase

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: fix after rebase

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: fix after rebase

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: fix after the latest rebase

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* Increment version to 2.6.0-SNAPSHOT (#671)

Signed-off-by: opensearch-ci-bot <opensearch-infra@amazon.com>

Signed-off-by: opensearch-ci-bot <opensearch-infra@amazon.com>
Co-authored-by: opensearch-ci-bot <opensearch-infra@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* fix profile API in example doc (#712)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* change model url to public repo in text embedding model example doc (#713)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* Enhance profile API to add model centric result controlled by view parameter (#714)

* Enhance profile API to add model centric result controled by view paramter

Signed-off-by: Zan Niu <zaniu@amazon.com>

* Enhance profile API to add model centric result controled by view parameter

Signed-off-by: Zan Niu <zaniu@amazon.com>

* Enhance profile API to add model centric result controled by view parameter

Signed-off-by: Zan Niu <zaniu@amazon.com>

---------

Signed-off-by: Zan Niu <zaniu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* add planning work nodes to model (#715)

* add planning work nodes to model

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add test

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* skip running syncup job if no model index (#717)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* refactor: add DL model class (#722)

* refactor: add DL model class

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix model url in example doc

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* address comments

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix failed ut

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* tune model config: change pooling mode to optional (#724)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: make the log readable

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: add error log

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: Refer to PR #717,just checking if index exists in cluster metadata

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: change RunTimeException to MLException

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: also consider COMPLETED_WITH_ERROR

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: remove ML_MODEL_RELOAD_MAX_RETRY_TIMES in CommonValue.java

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: remove Result class

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: change "reload" and "retry" to a full word

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: change log info sentence

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [wjunshen] #N/A feat: code format

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
1. Let the collection just have all ids of ml node
2. print out full exception stack trace

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
use LOAD_THREAD_POOL to replace generic

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
print the whole exception stack

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
change the IndexNotFoundException to MLException

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
add logs when receiving indexRequestBuilder.execute exception

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

* [Signed-off-by: wjunshen<wjunshen@amazon.com>] #N/A feat:
change the test code after code review

Signed-off-by: wujunshen <frank_wjs@hotmail.com>

---------

Signed-off-by: wujunshen <frank_wjs@hotmail.com>
Signed-off-by: opensearch-ci-bot <opensearch-infra@amazon.com>
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
Signed-off-by: Zan Niu <zaniu@amazon.com>
Co-authored-by: opensearch-trigger-bot[bot] <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
Co-authored-by: opensearch-ci-bot <opensearch-infra@amazon.com>
Co-authored-by: Yaliang Wu <ylwu@amazon.com>
Co-authored-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
5 people authored Feb 20, 2023
1 parent 2ce521c commit 7a51dcc
Show file tree
Hide file tree
Showing 6 changed files with 1,500 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ public class CommonValue {
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;

public static final String ML_MODEL_RELOAD_INDEX = ".plugins-ml-model-reload";
public static final String NODE_ID_FIELD = "node_id";
public static final String MODEL_LOAD_RETRY_TIMES_FIELD = "retry_times";
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down
331 changes: 331 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelAutoReloader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.model;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_RELOAD_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.common.CommonValue.MODEL_LOAD_RETRY_TIMES_FIELD;
import static org.opensearch.ml.common.CommonValue.NODE_ID_FIELD;
import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.StepListener;
import org.opensearch.action.index.IndexAction;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CollectionUtils;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.transport.load.MLLoadModelAction;
import org.opensearch.ml.common.transport.load.MLLoadModelRequest;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortBuilder;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.annotations.VisibleForTesting;

/**
* Manager class for ML models and nodes. It contains ML model auto reload operations etc.
*/
@Log4j2
public class MLModelAutoReloader {

private final Client client;
private final ClusterService clusterService;
private final NamedXContentRegistry xContentRegistry;
private final DiscoveryNodeHelper nodeHelper;
private final ThreadPool threadPool;
private volatile Boolean enableAutoReloadModel;
private volatile Integer autoReloadMaxRetryTimes;

/**
* constructor method, init all the params necessary for model auto reloading
*
* @param clusterService clusterService
* @param threadPool threadPool
* @param client client
* @param xContentRegistry xContentRegistry
* @param nodeHelper nodeHelper
* @param settings settings
*/
public MLModelAutoReloader(
ClusterService clusterService,
ThreadPool threadPool,
Client client,
NamedXContentRegistry xContentRegistry,
DiscoveryNodeHelper nodeHelper,
Settings settings
) {
this.clusterService = clusterService;
this.client = client;
this.xContentRegistry = xContentRegistry;
this.nodeHelper = nodeHelper;
this.threadPool = threadPool;

enableAutoReloadModel = ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE.get(settings);
autoReloadMaxRetryTimes = ML_MODEL_RELOAD_MAX_RETRY_TIMES.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE, it -> enableAutoReloadModel = it);

clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_MODEL_RELOAD_MAX_RETRY_TIMES, it -> autoReloadMaxRetryTimes = it);
}

/**
* the main method: model auto reloading
*/
public void autoReloadModel() {
log.info("auto reload model enabled: {} ", enableAutoReloadModel);

// if we don't need to reload automatically, just return without doing anything
if (!enableAutoReloadModel) {
return;
}

// At opensearch startup, get local node id, if not ml node,we ignored, just return without doing anything
if (!MLNodeUtils.isMLNode(clusterService.localNode())) {
return;
}

String localNodeId = clusterService.localNode().getId();
// auto reload all models of this local ml node
threadPool.executor(LOAD_THREAD_POOL).execute(() -> {
try {
autoReloadModelByNodeId(localNodeId);
} catch (ExecutionException | InterruptedException e) {
log.error("the model auto-reloading has exception,and the root cause message is: {}", e);
throw new MLException(e);
}
});
}

/**
* auto reload all the models under the node id<br> the node must be a ml node<br>
*
* @param localNodeId node id
*/
@VisibleForTesting
void autoReloadModelByNodeId(String localNodeId) throws ExecutionException, InterruptedException {
StepListener<SearchResponse> queryTaskStep = new StepListener<>();
StepListener<SearchResponse> getRetryTimesStep = new StepListener<>();
StepListener<IndexResponse> saveLatestRetryTimesStep = new StepListener<>();

if (!clusterService.state().metadata().indices().containsKey(ML_TASK_INDEX)) {
// ML_TASK_INDEX did not exist,do nothing
return;
}

queryTask(localNodeId, ActionListener.wrap(queryTaskStep::onResponse, queryTaskStep::onFailure));

getRetryTimes(localNodeId, ActionListener.wrap(getRetryTimesStep::onResponse, getRetryTimesStep::onFailure));

queryTaskStep.whenComplete(searchResponse -> {
SearchHit[] hits = searchResponse.getHits().getHits();
if (CollectionUtils.isEmpty(hits)) {
return;
}

getRetryTimesStep.whenComplete(getReTryTimesResponse -> {
int retryTimes = 0;
// if getReTryTimesResponse is null,it means we get retryTimes at the first time,and the index
// .plugins-ml-model-reload doesn't exist,so we should let retryTimes be zero(init value)
// we don't do anything
// if getReTryTimesResponse is not null,it means we have saved the value of retryTimes into the index
// .plugins-ml-model-reload,so we get the value of the field MODEL_LOAD_RETRY_TIMES_FIELD
if (getReTryTimesResponse != null) {
Map<String, Object> sourceAsMap = getReTryTimesResponse.getHits().getHits()[0].getSourceAsMap();
retryTimes = (Integer) sourceAsMap.get(MODEL_LOAD_RETRY_TIMES_FIELD);
}

// According to the node id to get retry times, if more than the max retry times, don't need to retry
// that the number of unsuccessful reload has reached the maximum number of times, do not need to reload
if (retryTimes > autoReloadMaxRetryTimes) {
log.info("Node: {} has reached to the max retry limit, failed to load models", localNodeId);
return;
}

try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, hits[0].getSourceRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLTask mlTask = MLTask.parse(parser);

autoReloadModelByNodeAndModelId(localNodeId, mlTask.getModelId());

// if reload the model successfully,the number of unsuccessful reload should be reset to zero.
retryTimes = 0;
} catch (MLException e) {
retryTimes++;
log.error("Can't auto reload model in node id {} ,has tried {} times\nThe reason is:{}", localNodeId, retryTimes, e);
}

// Store the latest value of the retryTimes and node id under the index ".plugins-ml-model-reload"
saveLatestRetryTimes(
localNodeId,
retryTimes,
ActionListener.wrap(saveLatestRetryTimesStep::onResponse, saveLatestRetryTimesStep::onFailure)
);
}, getRetryTimesStep::onFailure);
}, queryTaskStep::onFailure);

saveLatestRetryTimesStep.whenComplete(response -> log.info("successfully complete all steps"), saveLatestRetryTimesStep::onFailure);
}

/**
* auto reload 1 model under the node id
*
* @param localNodeId node id
* @param modelId model id
*/
@VisibleForTesting
void autoReloadModelByNodeAndModelId(String localNodeId, String modelId) throws MLException {
List<String> allMLNodeIdList = Arrays
.stream(nodeHelper.getAllNodes())
.filter(MLNodeUtils::isMLNode)
.map(DiscoveryNode::getId)
.collect(Collectors.toList());

if (!allMLNodeIdList.contains(localNodeId)) {
allMLNodeIdList.add(localNodeId);
}
MLLoadModelRequest mlLoadModelRequest = new MLLoadModelRequest(modelId, allMLNodeIdList.toArray(new String[] {}), false, false);

client
.execute(
MLLoadModelAction.INSTANCE,
mlLoadModelRequest,
ActionListener
.wrap(response -> log.info("the model {} is auto reloading under the node {} ", modelId, localNodeId), exception -> {
log.error("fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception);
throw new MLException(
"fail to reload model " + modelId + " under the node " + localNodeId + "\nthe reason is: " + exception
);
})
);
}

/**
* query task index, and get the result of "task_type"="LOAD_MODEL" and "state"="COMPLETED" and
* "worker_node" match nodeId
*
* @param localNodeId one of query condition
*/
@VisibleForTesting
void queryTask(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) {
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().from(0).size(1);

QueryBuilder queryBuilder = QueryBuilders
.boolQuery()
.must(QueryBuilders.matchPhraseQuery("task_type", "LOAD_MODEL"))
.must(QueryBuilders.matchPhraseQuery("worker_node", localNodeId))
.must(
QueryBuilders
.boolQuery()
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED"))
.should(QueryBuilders.matchPhraseQuery("state", "COMPLETED_WITH_ERROR"))
);
searchSourceBuilder.query(queryBuilder);

SortBuilder<FieldSortBuilder> sortBuilderOrder = new FieldSortBuilder("create_time").order(SortOrder.DESC);
searchSourceBuilder.sort(sortBuilderOrder);

SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
.setIndices(ML_TASK_INDEX)
.setSource(searchSourceBuilder);

searchRequestBuilder.execute(ActionListener.wrap(searchResponseActionListener::onResponse, exception -> {
log.error("index {} not found, the reason is {}", ML_TASK_INDEX, exception);
throw new MLException("index " + ML_TASK_INDEX + " not found");
}));
}

/**
* get retry times from the index ".plugins-ml-model-reload" by 1 ml node
*
* @param localNodeId the filter condition to query
*/
@VisibleForTesting
void getRetryTimes(String localNodeId, ActionListener<SearchResponse> searchResponseActionListener) {
if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_RELOAD_INDEX)) {
// ML_MODEL_RELOAD_INDEX did not exist, it means it is our first time to do model auto-reloading operation
searchResponseActionListener.onResponse(null);
}

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.fetchSource(new String[] { MODEL_LOAD_RETRY_TIMES_FIELD }, null);
QueryBuilder queryBuilder = QueryBuilders.idsQuery().addIds(localNodeId);
searchSourceBuilder.query(queryBuilder);

SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
.setIndices(ML_MODEL_RELOAD_INDEX)
.setSource(searchSourceBuilder);

searchRequestBuilder.execute(ActionListener.wrap(searchResponse -> {
SearchHit[] hits = searchResponse.getHits().getHits();
if (CollectionUtils.isEmpty(hits)) {
searchResponseActionListener.onResponse(null);
return;
}

searchResponseActionListener.onResponse(searchResponse);
}, searchResponseActionListener::onFailure));
}

/**
* save retry times
* @param localNodeId node id
* @param retryTimes actual retry times
*/
@VisibleForTesting
void saveLatestRetryTimes(String localNodeId, int retryTimes, ActionListener<IndexResponse> indexResponseActionListener) {
Map<String, Object> content = new HashMap<>(2);
content.put(NODE_ID_FIELD, localNodeId);
content.put(MODEL_LOAD_RETRY_TIMES_FIELD, retryTimes);

IndexRequestBuilder indexRequestBuilder = new IndexRequestBuilder(client, IndexAction.INSTANCE, ML_MODEL_RELOAD_INDEX)
.setId(localNodeId)
.setSource(content)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

indexRequestBuilder.execute(ActionListener.wrap(indexResponse -> {
if (indexResponse.status() == RestStatus.CREATED || indexResponse.status() == RestStatus.OK) {
log.info("node id:{} insert retry times successfully", localNodeId);
indexResponseActionListener.onResponse(indexResponse);
}
}, e -> {
log.error("node id:" + localNodeId + " insert retry times unsuccessfully", e);
indexResponseActionListener.onFailure(new MLException(e));
}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelAutoReloader;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.rest.RestMLCreateModelMetaAction;
Expand Down Expand Up @@ -175,6 +176,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
private MLModelChunkUploader mlModelChunkUploader;
private MLEngine mlEngine;

private MLModelAutoReloader mlModelAutoReloader;

private Client client;
private ClusterService clusterService;
private ThreadPool threadPool;
Expand Down Expand Up @@ -352,6 +355,9 @@ public Collection<Object> createComponents(
mlIndicesHandler
);

mlModelAutoReloader = new MLModelAutoReloader(clusterService, threadPool, client, xContentRegistry, nodeHelper, settings);
mlModelAutoReloader.autoReloadModel();

return ImmutableList
.of(
mlEngine,
Expand All @@ -373,7 +379,8 @@ public Collection<Object> createComponents(
modelHelper,
mlCommonsClusterEventListener,
clusterManagerEventListener,
mlCircuitBreakerService
mlCircuitBreakerService,
mlModelAutoReloader
);
}

Expand Down Expand Up @@ -513,7 +520,9 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE,
MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE,
MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX,
MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD
MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_RELOAD_ENABLE,
MLCommonsSettings.ML_MODEL_RELOAD_MAX_RETRY_TIMES
);
return settings;
}
Expand Down
Loading

0 comments on commit 7a51dcc

Please sign in to comment.