Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor persisting ML model #109

Merged
merged 1 commit into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
@RequiredArgsConstructor
@Log4j2
public class MLIndicesHandler {
public static final String ML_MODEL = ".plugins-ml-model";
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
private static final String ML_MODEL_INDEX_MAPPING = "{\n"
+ " \"properties\": {\n"
+ " \"task_id\": { \"type\": \"keyword\" },\n"
Expand All @@ -40,11 +40,11 @@ public class MLIndicesHandler {
Client client;

public void initModelIndexIfAbsent() {
initMLIndexIfAbsent(ML_MODEL, ML_MODEL_INDEX_MAPPING);
initMLIndexIfAbsent(ML_MODEL_INDEX, ML_MODEL_INDEX_MAPPING);
}

public boolean doesModelIndexExist() {
return clusterService.state().metadata().hasIndex(ML_MODEL);
return clusterService.state().metadata().hasIndex(ML_MODEL_INDEX);
}

private void initMLIndexIfAbsent(String indexName, String mapping) {
Expand Down
75 changes: 75 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/model/MLModel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.model;

import java.io.IOException;
import java.util.Base64;

import lombok.Builder;
import lombok.Getter;

import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.engine.Model;

/**
 * Value object describing a trained ML model and how to render it as XContent.
 *
 * <p>Getters for all instance fields are generated by Lombok's {@code @Getter};
 * a builder is generated from the all-args constructor via {@code @Builder}.
 */
@Getter
public class MLModel implements ToXContentObject {
    // Field names emitted by toXContent.
    public static final String ALGORITHM = "algorithm";
    public static final String MODEL_NAME = "name";
    public static final String MODEL_VERSION = "version";
    public static final String MODEL_CONTENT = "content";
    public static final String USER = "user";

    private String name;
    private FunctionName algorithm;
    private Integer version;
    private String content;
    private User user;

    /**
     * Creates a model from explicit field values. Any argument may be {@code null};
     * null fields are simply left out when the model is rendered.
     *
     * @param name      model name
     * @param algorithm algorithm the model belongs to
     * @param version   model version
     * @param content   model content, already in its string form
     * @param user      user associated with the model
     */
    @Builder
    public MLModel(String name, FunctionName algorithm, Integer version, String content, User user) {
        this.name = name;
        this.algorithm = algorithm;
        this.version = version;
        this.content = content;
        this.user = user;
    }

    /**
     * Creates a model from an engine {@link Model}: name and version are copied,
     * the binary content is Base64-encoded, and no user is attached.
     *
     * @param algorithm algorithm the model belongs to
     * @param model     engine model supplying name, version and raw content
     */
    public MLModel(FunctionName algorithm, Model model) {
        this(model.getName(), algorithm, model.getVersion(), encodeContent(model.getContent()), null);
    }

    /** Base64-encodes raw model bytes with the standard (non-URL) encoder. */
    private static String encodeContent(byte[] rawContent) {
        return Base64.getEncoder().encodeToString(rawContent);
    }

    /**
     * Writes this model as a single object, emitting only the fields that are
     * non-null, in the order name, algorithm, version, content, user.
     *
     * @param builder builder to write into
     * @param params  rendering parameters (not consulted here)
     * @return the same {@code builder}, after the object has been closed
     * @throws IOException if the builder fails to write
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        if (name != null) {
            builder.field(MODEL_NAME, name);
        }

        if (algorithm != null) {
            builder.field(ALGORITHM, algorithm);
        }

        if (version != null) {
            builder.field(MODEL_VERSION, version);
        }

        if (content != null) {
            builder.field(MODEL_CONTENT, content);
        }

        if (user != null) {
            builder.field(USER, user);
        }

        builder.endObject();
        return builder;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

package org.opensearch.ml.task;

import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
import static org.opensearch.ml.permission.AccessController.getUserContext;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
Expand Down Expand Up @@ -172,7 +172,7 @@ private void predict(
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
QueryBuilder queryBuilder = QueryBuilders.termQuery(TASK_ID, request.getModelId());
searchSourceBuilder.query(queryBuilder);
SearchRequest searchRequest = new SearchRequest(new String[] { ML_MODEL }, searchSourceBuilder);
SearchRequest searchRequest = new SearchRequest(new String[] { ML_MODEL_INDEX }, searchSourceBuilder);

// Search model.
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@

package org.opensearch.ml.task;

import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL;
import static org.opensearch.ml.permission.AccessController.getUserStr;
import static org.opensearch.ml.indices.MLIndicesHandler.ML_MODEL_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL;
import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT;

import java.time.Instant;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.action.training.MLTrainingTaskExecutionAction;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
Expand All @@ -43,6 +43,7 @@
import org.opensearch.ml.engine.Model;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModel;
import org.opensearch.ml.model.MLTask;
import org.opensearch.ml.model.MLTaskState;
import org.opensearch.ml.model.MLTaskType;
Expand Down Expand Up @@ -113,20 +114,19 @@ public void startTrainingTask(MLTrainingTaskRequest request, ActionListener<MLTr
.createTime(Instant.now())
.state(MLTaskState.CREATED)
.build();
// TODO: move this listener onResponse later to catch the following cases:
// 1). search data failure, 2) train model failure, 3) persist model failure.
MLTrainingOutput output = MLTrainingOutput.builder().modelId(mlTask.getTaskId()).status(MLTaskState.CREATED.name()).build();
listener.onResponse(MLTrainingTaskResponse.builder().output(output).build());
MLInput mlInput = request.getMlInput();
if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
ActionListener<DataFrame> dataFrameActionListener = ActionListener
.wrap(
dataFrame -> { train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build()); },
dataFrame -> {
train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build(), listener);
},
e -> {
log.error("Failed to generate DataFrame from search query", e);
mlTaskManager.addIfAbsent(mlTask);
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED);
mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage());
listener.onFailure(e);
}
);
mlInputDatasetHandler
Expand All @@ -135,37 +135,40 @@ public void startTrainingTask(MLTrainingTaskRequest request, ActionListener<MLTr
new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
);
} else {
threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput); });
threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput, listener); });
}
}

private void train(MLTask mlTask, MLInput mlInput) {
// track ML task count and add ML task into cache
mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
mlTaskManager.add(mlTask);
// run training
private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTrainingTaskResponse> listener) {
try {
// track ML task count and add ML task into cache
mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment();
mlTaskManager.add(mlTask);
// run training
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING);
Model model = MLEngine.train(mlInput);
String encodedModelContent = Base64.getEncoder().encodeToString(model.getContent());
mlIndicesHandler.initModelIndexIfAbsent();
Map<String, Object> source = new HashMap<>();
source.put(TASK_ID, mlTask.getTaskId());
source.put(ALGORITHM, mlInput.getAlgorithm());
source.put(MODEL_NAME, model.getName());
source.put(MODEL_VERSION, model.getVersion());
source.put(MODEL_CONTENT, encodedModelContent);

// put the user into model for backend role based access control.
source.put(USER, getUserStr(client));
// TODO: put the user into model for backend role based access control.
MLModel mlModel = new MLModel(mlInput.getAlgorithm(), model);

IndexResponse response = client.prepareIndex(ML_MODEL, "_doc").setSource(source).get();
log.info("mode data indexing done, result:{}", response.getResult());
handleMLTaskComplete(mlTask);
IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX)
.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(r -> {
log.info("mode data indexing done, result:{}", r.getResult());
handleMLTaskComplete(mlTask);
MLTrainingOutput output = MLTrainingOutput.builder().modelId(r.getId()).status(MLTaskState.CREATED.name()).build();
listener.onResponse(MLTrainingTaskResponse.builder().output(output).build());
}, e -> {
handleMLTaskFailure(mlTask, e);
listener.onFailure(e);
}));
} catch (Exception e) {
// TODO: catch specific exception types instead of the generic Exception
log.error("Failed to train " + mlInput.getAlgorithm(), e);
handleMLTaskFailure(mlTask, e);
listener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
package org.opensearch.ml.action.training;

import static org.opensearch.ml.utils.IntegTestUtils.DATA_FRAME_INPUT_DATASET;
import static org.opensearch.ml.utils.IntegTestUtils.ML_MODEL;
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA;
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_INDEX_NAME;
import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData;
Expand All @@ -31,16 +30,11 @@
import org.junit.Ignore;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchAction;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLTrainingOutput;
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse;
Expand Down Expand Up @@ -107,35 +101,6 @@ public void testTrainingWithEmptyDataset() throws InterruptedException {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build();
MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest(mlInput);

ActionFuture<MLTrainingTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest);
MLTrainingTaskResponse trainingResponse = trainingFuture.actionGet();

// The training task ID and status are returned to the client.
assertNotNull(trainingResponse);
MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) trainingResponse.getOutput();
String modelId = modelTrainingOutput.getModelId();
String status = modelTrainingOutput.getStatus();
assertNotNull(modelId);
assertFalse(modelId.isEmpty());
assertEquals("CREATED", status);

SearchSourceBuilder modelSearchSourceBuilder = new SearchSourceBuilder();
QueryBuilder queryBuilder = QueryBuilders.termQuery("taskId", modelId);
modelSearchSourceBuilder.query(queryBuilder);
SearchRequest modelSearchRequest = new SearchRequest(new String[] { ML_MODEL }, modelSearchSourceBuilder);
SearchResponse modelSearchResponse = null;
int i = 0;
while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value == 0) && i < 100) {
try {
ActionFuture<SearchResponse> searchFuture = client().execute(SearchAction.INSTANCE, modelSearchRequest);
modelSearchResponse = searchFuture.actionGet();
} catch (Exception e) {} finally {
// Wait 100 ms between attempts until a valid search response arrives or the retries run out.
Thread.sleep(100);
}
i++;
}
// No model would be trained successfully with empty dataset.
assertNull(modelSearchResponse);
expectThrows(IllegalArgumentException.class, () -> client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest).actionGet());
}
}
44 changes: 44 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/model/MLModelTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.model;

import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.IOException;

import org.junit.Test;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.utils.TestHelper;

public class MLModelTests {

/**
 * Checks that a model with name, algorithm, version and content set (and no user)
 * renders exactly those four fields as JSON.
 */
@Test
public void toXContent() throws IOException {
    MLModel model = MLModel.builder()
        .algorithm(FunctionName.KMEANS)
        .name("model_name")
        .version(1)
        .content("test_content")
        .build();
    XContentBuilder jsonBuilder = XContentBuilder.builder(XContentType.JSON.xContent());

    model.toXContent(jsonBuilder, EMPTY_PARAMS);

    String expectedJson = "{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"version\":1,\"content\":\"test_content\"}";
    String actualJson = TestHelper.xContentBuilderToString(jsonBuilder);
    assertEquals(expectedJson, actualJson);
}

@Test
public void toXContent_NullValue() throws IOException {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe have another case for empty value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add this in the next PR.

MLModel mlModel = MLModel.builder().build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{}", mlModelContent);
}
}
14 changes: 14 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
import java.io.IOException;
import java.util.Collections;

import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.ToXContentObject;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.search.SearchModule;
Expand All @@ -39,4 +44,13 @@ public static NamedXContentRegistry xContentRegistry() {
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList());
return new NamedXContentRegistry(searchModule.getNamedXContents());
}

/**
 * Renders the given object into a fresh JSON builder using empty params and
 * returns the result as a string.
 *
 * @param object object to render
 * @return the rendered content as a UTF-8 decoded string
 * @throws IOException if creating the builder or rendering fails
 */
public static String toJsonString(ToXContentObject object) throws IOException {
    return xContentBuilderToString(object.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
}

/**
 * Extracts the bytes accumulated in the builder and decodes them as UTF-8.
 *
 * @param builder builder whose content should be read
 * @return the builder's content as a string
 */
public static String xContentBuilderToString(XContentBuilder builder) {
    BytesReference renderedBytes = BytesReference.bytes(builder);
    return renderedBytes.utf8ToString();
}
}