Skip to content

Commit

Permalink
refactor persisting ML model (#109)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Dec 23, 2021
1 parent 2fed70a commit e3fc904
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@RequiredArgsConstructor
@Log4j2
public class MLIndicesHandler {
public static final String ML_MODEL = ".plugins-ml-model";
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
private static final String ML_MODEL_INDEX_MAPPING = "{\n"
+ " \"properties\": {\n"
+ " \"task_id\": { \"type\": \"keyword\" },\n"
Expand All @@ -40,11 +40,11 @@ public class MLIndicesHandler {
Client client;

public void initModelIndexIfAbsent() {
initMLIndexIfAbsent(ML_MODEL, ML_MODEL_INDEX_MAPPING);
initMLIndexIfAbsent(ML_MODEL_INDEX, ML_MODEL_INDEX_MAPPING);
}

public boolean doesModelIndexExist() {
return clusterService.state().metadata().hasIndex(ML_MODEL);
return clusterService.state().metadata().hasIndex(ML_MODEL_INDEX);
}

private void initMLIndexIfAbsent(String indexName, String mapping) {
Expand Down
75 changes: 75 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.model;

import java.io.IOException;
import java.util.Base64;

import lombok.Builder;
import lombok.Getter;

import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.engine.Model;

/**
 * Immutable description of a trained ML model that can be rendered as an
 * XContent object via {@link #toXContent(XContentBuilder, Params)}.
 *
 * <p>Every attribute is optional: an attribute that is {@code null} is simply
 * left out of the rendered object.
 */
@Getter
public class MLModel implements ToXContentObject {
    /** Field name under which the algorithm is written. */
    public static final String ALGORITHM = "algorithm";
    /** Field name under which the model name is written. */
    public static final String MODEL_NAME = "name";
    /** Field name under which the model version is written. */
    public static final String MODEL_VERSION = "version";
    /** Field name under which the model content is written. */
    public static final String MODEL_CONTENT = "content";
    /** Field name under which the user is written. */
    public static final String USER = "user";

    // All state is assigned exactly once in the constructor and only exposed
    // through getters, so the fields are final to make the immutability explicit.
    private final String name;
    private final FunctionName algorithm;
    private final Integer version;
    private final String content;
    private final User user;

    /**
     * Creates a model description from its individual attributes.
     *
     * @param name      model name, may be {@code null}
     * @param algorithm algorithm of the model, may be {@code null}
     * @param version   model version, may be {@code null}
     * @param content   model content as a string, may be {@code null}
     * @param user      user associated with the model, may be {@code null}
     */
    @Builder
    public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) {
        this.name = name;
        this.algorithm = algorithm;
        this.version = version;
        this.content = content;
        this.user = user;
    }

    /**
     * Creates a model description from an engine {@link Model}.
     *
     * <p>The name and version are taken from {@code model}; its content bytes are
     * stored Base64-encoded. The user is left unset ({@code null}).
     *
     * @param algorithm algorithm of the model
     * @param model     engine model supplying name, version and content
     */
    public MLModel(FunctionName algorithm, Model model) {
        this(model.getName(), algorithm, model.getVersion(), Base64.getEncoder().encodeToString(model.getContent()), null);
    }

    /**
     * Writes this model as a single object, emitting the non-null attributes in the
     * order name, algorithm, version, content, user.
     *
     * @param builder builder to write into
     * @param params  serialization parameters (not consulted by this implementation)
     * @return the same {@code builder} that was passed in
     * @throws IOException if the builder fails while writing
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (name != null) {
            builder.field(MODEL_NAME, name);
        }
        if (algorithm != null) {
            builder.field(ALGORITHM, algorithm);
        }
        if (version != null) {
            builder.field(MODEL_VERSION, version);
        }
        if (content != null) {
            builder.field(MODEL_CONTENT, content);
        }
        if (user != null) {
            builder.field(USER, user);
        }
        builder.endObject();
        return builder;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

package org.opensearch.ml.task;

import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
import static org.opensearch.ml.permission.AccessController.getUserContext;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
Expand Down Expand Up @@ -172,7 +172,7 @@ private void predict(
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
QueryBuilder queryBuilder = QueryBuilders.termQuery(TASK_ID, request.getModelId());
searchSourceBuilder.query(queryBuilder);
SearchRequest searchRequest = new SearchRequest(new String[] { ML_MODEL }, searchSourceBuilder);
SearchRequest searchRequest = new SearchRequest(new String[] { ML_MODEL_INDEX }, searchSourceBuilder);

// Search model.
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@

package org.opensearch.ml.task;

import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL;
import static org.opensearch.ml.permission.AccessController.getUserStr;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT;

import java.time.Instant;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.action.training.MLTrainingTaskExecutionAction;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
Expand All @@ -43,6 +43,7 @@
import org.opensearch.ml.engine.Model;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModel;
import org.opensearch.ml.model.MLTask;
import org.opensearch.ml.model.MLTaskState;
import org.opensearch.ml.model.MLTaskType;
Expand Down Expand Up @@ -113,20 +114,19 @@ public void startTrainingTask(MLTrainingTaskRequest request, ActionListener<MLTr
.createTime(Instant.now())
.state(MLTaskState.CREATED)
.build();
// TODO: move this listener onResponse later to catch the following cases:
// 1). search data failure, 2) train model failure, 3) persist model failure.
MLTrainingOutput output = MLTrainingOutput.builder().modelId(mlTask.getTaskId()).status(MLTaskState.CREATED.name()).build();
listener.onResponse(MLTrainingTaskResponse.builder().output(output).build());
MLInput mlInput = request.getMlInput();
if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
ActionListener<DataFrame> dataFrameActionListener = ActionListener
.wrap(
dataFrame -> { train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build()); },
dataFrame -> {
train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build(), listener);
},
e -> {
log.error("Failed to generate DataFrame from search query", e);
mlTaskManager.addIfAbsent(mlTask);
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED);
mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage());
listener.onFailure(e);
}
);
mlInputDatasetHandler
Expand All @@ -135,37 +135,40 @@ public void startTrainingTask(MLTrainingTaskRequest request, ActionListener<MLTr
new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
);
} else {
threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput); });
threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput, listener); });
}
}

private void train(MLTask mlTask, MLInput mlInput) {
// track ML task count and add ML task into cache
mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
mlTaskManager.add(mlTask);
// run training
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTrainingTaskResponse> listener) {
try {
// track ML task count and add ML task into cache
mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
mlTaskManager.add(mlTask);
// run training
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING);
Model model = MLEngine.train(mlInput);
String encodedModelContent = Base64.getEncoder().encodeToString(model.getContent());
mlIndicesHandler.initModelIndexIfAbsent();
Map<String, Object> source = new HashMap<>();
source.put(TASK_ID, mlTask.getTaskId());
source.put(ALGORITHM, mlInput.getAlgorithm());
source.put(MODEL_NAME, model.getName());
source.put(MODEL_VERSION, model.getVersion());
source.put(MODEL_CONTENT, encodedModelContent);

// put the user into model for backend role based access control.
source.put(USER, getUserStr(client));
// TODO: put the user into model for backend role based access control.
MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);

IndexResponse response = client.prepareIndex(ML_MODEL, "_doc").setSource(source).get();
log.info("mode data indexing done, result:{}", response.getResult());
handleMLTaskComplete(mlTask);
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX)
.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(r -> {
log.info("mode data indexing done, result:{}", r.getResult());
handleMLTaskComplete(mlTask);
MLTrainingOutput output = MLTrainingOutput.builder().modelId(r.getId()).status(MLTaskState.CREATED.name()).build();
listener.onResponse(MLTrainingTaskResponse.builder().output(output).build());
}, e -> {
handleMLTaskFailure(mlTask, e);
listener.onFailure(e);
}));
} catch (Exception e) {
// TODO: specify which exception types should be caught here
log.error("Failed to train " + mlInput.getAlgorithm(), e);
handleMLTaskFailure(mlTask, e);
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package org.opensearch.ml.action.training;

import static org.opensearch.ml.utils.IntegTestUtils.DATA_FRAME_INPUT_DATASET;
import static org.opensearch.ml.utils.IntegTestUtils.ML_MODEL;
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA;
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_INDEX_NAME;
import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData;
Expand All @@ -31,16 +30,11 @@
import org.junit.Ignore;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLTrainingOutput;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse;
Expand Down Expand Up @@ -107,35 +101,6 @@ public void testTrainingWithEmptyDataset() throws InterruptedException {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput);

ActionFuture<MLTrainingTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest);
MLTrainingTaskResponse trainingResponse = trainingFuture.actionGet();

// The training task ID and status are returned to the client.
assertNotNull(trainingResponse);
MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) trainingResponse.getOutput();
String modelId = modelTrainingOutput.getModelId();
String status = modelTrainingOutput.getStatus();
assertNotNull(modelId);
assertFalse(modelId.isEmpty());
assertEquals("CREATED", status);

SearchSourceBuilder modelSearchSourceBuilder = new SearchSourceBuilder();
QueryBuilder queryBuilder = QueryBuilders.termQuery("taskId", modelId);
modelSearchSourceBuilder.query(queryBuilder);
SearchRequest modelSearchRequest = new SearchRequest(new String[] { ML_MODEL }, modelSearchSourceBuilder);
SearchResponse modelSearchResponse = null;
int i = 0;
while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value == 0) && i < 100) {
try {
ActionFuture<SearchResponse> searchFuture = client().execute(SearchAction.INSTANCE, modelSearchRequest);
modelSearchResponse = searchFuture.actionGet();
} catch (Exception e) {} finally {
// Wait 100 ms until get valid search response or timeout.
Thread.sleep(100);
}
i++;
}
// No model would be trained successfully with empty dataset.
assertNull(modelSearchResponse);
expectThrows(IllegalArgumentException.class, () -> client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest).actionGet());
}
}
44 changes: 44 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/model/MLModelTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.model;

import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.IOException;

import org.junit.Test;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.utils.TestHelper;

/**
 * Unit tests for the XContent rendering of {@link MLModel}.
 */
public class MLModelTests {

    /**
     * Renders {@code model} into a fresh JSON builder and returns the produced text.
     */
    private static String render(MLModel model) throws IOException {
        XContentBuilder jsonBuilder = XContentBuilder.builder(XContentType.JSON.xContent());
        model.toXContent(jsonBuilder, EMPTY_PARAMS);
        return TestHelper.xContentBuilderToString(jsonBuilder);
    }

    /** A model with name, algorithm, version and content set renders exactly those four fields. */
    @Test
    public void toXContent() throws IOException {
        MLModel populated = MLModel
            .builder()
            .algorithm(FunctionName.KMEANS)
            .name("model_name")
            .version(1)
            .content("test_content")
            .build();
        String expectedJson = "{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}";
        assertEquals(expectedJson, render(populated));
    }

    /** A model built without any attribute renders as an empty object. */
    @Test
    public void toXContent_NullValue() throws IOException {
        MLModel blank = MLModel.builder().build();
        assertEquals("{}", render(blank));
    }
}
14 changes: 14 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
import java.io.IOException;
import java.util.Collections;

import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.search.SearchModule;
Expand All @@ -39,4 +44,13 @@ public static NamedXContentRegistry xContentRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
return new NamedXContentRegistry(searchModule.getNamedXContents());
}

/**
 * Renders {@code object} into a new JSON builder using empty parameters and
 * returns the builder's content as a string.
 *
 * @param object the object to render
 * @return the string form of the rendered builder
 * @throws IOException if creating the builder or rendering the object fails
 */
public static String toJsonString(ToXContentObject object) throws IOException {
    XContentBuilder rendered = object.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
    return xContentBuilderToString(rendered);
}

/**
 * Converts the bytes held by {@code builder} into a UTF-8 string.
 *
 * @param builder the builder whose content should be read
 * @return the builder's bytes interpreted as UTF-8 text
 */
public static String xContentBuilderToString(XContentBuilder builder) {
    BytesReference serialized = BytesReference.bytes(builder);
    return serialized.utf8ToString();
}
}

0 comments on commit e3fc904

Please sign in to comment.