diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index 8038655e90..91d8a651af 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -19,6 +19,9 @@ import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataType; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.parameter.MLParameter; /** @@ -29,7 +32,7 @@ public interface MachineLearningClient { /** * Do prediction machine learning job * @param algorithm algorithm name - * @param inputData input data set + * @param inputData input data frame * @return the result future */ default ActionFuture predict(String algorithm, DataFrame inputData) { @@ -40,7 +43,7 @@ default ActionFuture predict(String algorithm, DataFrame inputData) { * Do prediction machine learning job * @param algorithm algorithm name * @param parameters parameters of ml algorithm - * @param inputData input data set + * @param inputData input data frame * @return the result future */ default ActionFuture predict(String algorithm, List parameters, DataFrame inputData) { @@ -51,7 +54,7 @@ default ActionFuture predict(String algorithm, List para * Do prediction machine learning job * @param algorithm algorithm name * @param parameters parameters of ml algorithm - * @param inputData input data set + * @param inputData input data frame * @param modelId the trained model id * @return the result future */ @@ -64,7 +67,7 @@ default ActionFuture predict(String algorithm, List para /** * Do prediction machine learning job * @param algorithm algorithm name - * @param inputData input data set + * @param inputData input data frame * @param listener a listener to be notified of the result */ default void predict(String algorithm, DataFrame inputData, ActionListener listener) { @@ -75,13 +78,25 @@ default void predict(String algorithm, DataFrame inputData, ActionListener parameters, DataFrame inputData, ActionListener listener){ predict(algorithm, parameters, inputData, null, listener); } + /** + * Do prediction machine learning job + * @param algorithm algorithm name + * @param parameters parameters of ml algorithm + * @param inputData input data frame + * @param modelId the trained model id + * @param listener a listener to be notified of the result + */ + default void predict(String algorithm, List parameters, DataFrame inputData, String modelId, ActionListener listener) { + predict(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), modelId, listener); + } + /** * Do prediction machine learning job * @param algorithm algorithm name @@ -90,13 +105,13 @@ default void predict(String algorithm, List parameters, DataFrame i * @param modelId the trained model id * @param listener a listener to be notified of the result */ - void predict(String algorithm, List parameters, DataFrame inputData, String modelId, ActionListener listener); + void predict(String algorithm, List parameters, MLInputDataset inputData, String modelId, ActionListener listener); /** * Do the training machine learning job. The training job will be always async process. The job id will be returned in this method. * @param algorithm algorithm name * @param parameters parameters of ml algorithm - * @param inputData input data set + * @param inputData input data frame * @return the result future */ default ActionFuture train(String algorithm, List parameters, DataFrame inputData) { @@ -105,6 +120,18 @@ default ActionFuture train(String algorithm, List parameter return actionFuture; } + /** + * Do the training machine learning job. The training job will be always async process. The job id will be returned in this method. + * @param algorithm algorithm name + * @param parameters parameters of ml algorithm + * @param inputData input data frame + * @param listener a listener to be notified of the result + */ + default void train(String algorithm, List parameters, DataFrame inputData, ActionListener listener) { + train(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), listener); + } + + /** * Do the training machine learning job. The training job will be always async process. The job id will be returned in this method. * @param algorithm algorithm name @@ -112,6 +139,6 @@ default ActionFuture train(String algorithm, List parameter * @param inputData input data set * @param listener a listener to be notified of the result */ - void train(String algorithm, List parameters, DataFrame inputData, ActionListener listener); + void train(String algorithm, List parameters, MLInputDataset inputData, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 84325b7cae..37bfd4b462 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -20,6 +20,7 @@ import org.opensearch.common.Strings; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.parameter.MLParameter; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; @@ -38,20 +39,20 @@ public class MachineLearningNodeClient implements MachineLearningClient { NodeClient client; @Override - public void predict(String algorithm, List parameters, DataFrame inputData, String modelId, + public void predict(String algorithm, List parameters, MLInputDataset inputData, String modelId, ActionListener listener) { if(Strings.isNullOrEmpty(algorithm)) { throw new IllegalArgumentException("algorithm name can't be null or empty"); } - if(Objects.isNull(inputData) || inputData.size() <= 0) { - throw new IllegalArgumentException("input data frame can't be null or empty"); + if(Objects.isNull(inputData)) { + throw new IllegalArgumentException("input data set can't be null"); } MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder() .algorithm(algorithm) .modelId(modelId) .parameters(parameters) - .dataFrame(inputData) + .inputDataset(inputData) .build(); client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> { @@ -64,17 +65,17 @@ public void predict(String algorithm, List parameters, DataFrame in } @Override - public void train(String algorithm, List parameters, DataFrame inputData, ActionListener listener) { + public void train(String algorithm, List parameters, MLInputDataset inputData, ActionListener listener) { if(Strings.isNullOrEmpty(algorithm)) { throw new IllegalArgumentException("algorithm name can't be null or empty"); } - if(Objects.isNull(inputData) || inputData.size() <= 0) { - throw new IllegalArgumentException("input data frame can't be null or empty"); + if(Objects.isNull(inputData)) { + throw new IllegalArgumentException("input data set can't be null"); } MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder() .algorithm(algorithm) - .dataFrame(inputData) + .inputDataset(inputData) .parameters(parameters) .build(); diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 88689ed2c2..fb93f5032e 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -21,6 +21,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.parameter.MLParameter; import static org.junit.Assert.assertEquals; @@ -49,13 +50,13 @@ public void setUp() throws Exception { machineLearningClient = new MachineLearningClient() { @Override - public void predict(String algorithm, List parameters, DataFrame inputData, String modelId, + public void predict(String algorithm, List parameters, MLInputDataset inputData, String modelId, ActionListener listener) { listener.onResponse(output); } @Override - public void train(String algorithm, List parameters, DataFrame inputData, + public void train(String algorithm, List parameters, MLInputDataset inputData, ActionListener listener) { listener.onResponse("taskId"); } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 112978034f..da68021a33 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -22,6 +22,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.client.node.NodeClient; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; @@ -44,7 +45,7 @@ public class MachineLearningNodeClientTest { NodeClient client; @Mock - DataFrame input; + MLInputDataset input; @Mock DataFrame output; @@ -64,7 +65,6 @@ public class MachineLearningNodeClientTest { @Before public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - when(input.size()).thenReturn(1); } @Test @@ -96,18 +96,10 @@ public void predict_Exception_WithNullAlgorithm() { } @Test - public void predict_Exception_WithEmptyDataFrame() { + public void predict_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("input data frame can't be null or empty"); - when(input.size()).thenReturn(0); - machineLearningNodeClient.predict("algo", null, input, null, dataFrameActionListener); - } - - @Test - public void predict_Exception_WithNullDataFrame() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("input data frame can't be null or empty"); - machineLearningNodeClient.predict("algo", null, null, null, dataFrameActionListener); + exceptionRule.expectMessage("input data set can't be null"); + machineLearningNodeClient.predict("algo", null, (MLInputDataset) null, null, dataFrameActionListener); } @Test @@ -139,17 +131,9 @@ public void train_Exception_WithNullAlgorithm() { } @Test - public void train_Exception_WithEmptyDataFrame() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("input data frame can't be null or empty"); - when(input.size()).thenReturn(0); - machineLearningNodeClient.train("algo", null, input, trainingActionListener); - } - - @Test - public void train_Exception_WithNullDataFrame() { + public void train_Exception_WithNullDataSet() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("input data frame can't be null or empty"); - machineLearningNodeClient.train("algo", null, null, trainingActionListener); + exceptionRule.expectMessage("input data set can't be null"); + machineLearningNodeClient.train("algo", null, (MLInputDataset)null, trainingActionListener); } } \ No newline at end of file diff --git a/common/build.gradle b/common/build.gradle index 0f88fee859..47707ad02c 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -19,7 +19,6 @@ plugins { dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}-SNAPSHOT" testCompile group: 'junit', name: 'junit', version: '4.12' - } jacocoTestReport { @@ -46,4 +45,6 @@ jacocoTestCoverageVerification { } dependsOn jacocoTestReport } -check.dependsOn jacocoTestCoverageVerification \ No newline at end of file +check.dependsOn jacocoTestCoverageVerification + +lombok.config['lombok.nonNull.exceptionType'] = 'JDK' \ No newline at end of file diff --git a/common/lombok.config b/common/lombok.config index 189c0bef98..3837e1974e 100644 --- a/common/lombok.config +++ b/common/lombok.config @@ -1,3 +1,4 @@ # This file is generated by the 'io.freefair.lombok' Gradle plugin config.stopBubbling = true lombok.addLombokGeneratedAnnotation = true +lombok.nonNull.exceptionType = JDK diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java new file mode 100644 index 0000000000..64c8579dfc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.ml.common.dataframe.DataFrame; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.experimental.FieldDefaults; + +/** + * DataFrame based input data. Client directly passes the data frame to ML plugin with this. + */ +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class DataFrameInputDataset extends MLInputDataset { + DataFrame dataFrame; + + @Builder + public DataFrameInputDataset(@NonNull DataFrame dataFrame) { + super(MLInputDataType.DATA_FRAME); + this.dataFrame = dataFrame; + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + super.writeTo(streamOutput); + dataFrame.writeTo(streamOutput); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java new file mode 100644 index 0000000000..1ab387a2cc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataType.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +public enum MLInputDataType { + SEARCH_QUERY, + DATA_FRAME; +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java new file mode 100644 index 0000000000..766a9f0171 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; + +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@RequiredArgsConstructor +public abstract class MLInputDataset implements Writeable { + MLInputDataType inputDataType; + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeEnum(this.inputDataType); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDatasetReader.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDatasetReader.java new file mode 100644 index 0000000000..314faa17df --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDatasetReader.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; + +public class MLInputDatasetReader implements Writeable.Reader { + @Override + public MLInputDataset read(StreamInput streamInput) throws IOException { + MLInputDataType inputDataType = streamInput.readEnum(MLInputDataType.class); + switch (inputDataType) { + case DATA_FRAME: + return new DataFrameInputDataset(DataFrameBuilder.load(streamInput)); + case SEARCH_QUERY: + return new SearchQueryInputDataset(streamInput); + default: + throw new IllegalArgumentException("unknown input data type:" + inputDataType); + } + + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java new file mode 100644 index 0000000000..e88c909f64 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.search.builder.SearchSourceBuilder; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.experimental.FieldDefaults; + +/** + * Search query based input data. The client just need give the search query, and ML plugin will read the data based on it, + * and build the data frame for algorithm execution. + */ +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class SearchQueryInputDataset extends MLInputDataset { + + SearchSourceBuilder searchSourceBuilder; + + List indices; + + @Builder + public SearchQueryInputDataset(@NonNull List indices, @NonNull SearchSourceBuilder searchSourceBuilder) { + super(MLInputDataType.SEARCH_QUERY); + if (indices.isEmpty()) { + throw new IllegalArgumentException("indices can't be empty"); + } + + this.indices = indices; + this.searchSourceBuilder = searchSourceBuilder; + } + + public SearchQueryInputDataset(StreamInput streaminput) throws IOException { + super(MLInputDataType.SEARCH_QUERY); + this.searchSourceBuilder = new SearchSourceBuilder(streaminput); + this.indices = streaminput.readStringList(); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + super.writeTo(streamOutput); + searchSourceBuilder.writeTo(streamOutput); + streamOutput.writeStringCollection(indices); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 1a3f754e57..5e945eefd3 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -26,8 +26,8 @@ import org.opensearch.common.io.stream.OutputStreamStreamOutput; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.MLInputDatasetReader; import org.opensearch.ml.common.parameter.MLParameter; import lombok.AccessLevel; @@ -57,7 +57,7 @@ public class MLPredictionTaskRequest extends ActionRequest { * input data set */ @ToString.Exclude - DataFrame dataFrame; + MLInputDataset inputDataset; /** * Trained model id @@ -71,11 +71,11 @@ public class MLPredictionTaskRequest extends ActionRequest { @Builder public MLPredictionTaskRequest(String algorithm, List parameters, - String modelId, DataFrame dataFrame) { + String modelId, MLInputDataset inputDataset) { this.algorithm = algorithm; this.parameters = parameters; this.modelId = modelId; - this.dataFrame = dataFrame; + this.inputDataset = inputDataset; this.version = 1; } @@ -85,7 +85,7 @@ public MLPredictionTaskRequest(StreamInput in) throws IOException { this.algorithm = in.readString(); this.parameters = in.readList(MLParameter::new); this.modelId = in.readOptionalString(); - this.dataFrame = DataFrameBuilder.load(in); + this.inputDataset = new MLInputDatasetReader().read(in); } @Override @@ -95,7 +95,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(this.algorithm); out.writeList(this.parameters); out.writeOptionalString(this.modelId); - this.dataFrame.writeTo(out); + this.inputDataset.writeTo(out); } @Override @@ -104,8 +104,8 @@ public ActionRequestValidationException validate() { if(Strings.isNullOrEmpty(this.algorithm)) { exception = addValidationError("algorithm name can't be null or empty", exception); } - if(Objects.isNull(this.dataFrame) || this.dataFrame.size() < 1) { - exception = addValidationError("input data can't be null or empty", exception); + if(Objects.isNull(this.inputDataset)) { + exception = addValidationError("input data can't be null", exception); } return exception; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java index 55050e506f..d1d9aff4d2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java @@ -26,8 +26,8 @@ import org.opensearch.common.io.stream.OutputStreamStreamOutput; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.MLInputDatasetReader; import org.opensearch.ml.common.parameter.MLParameter; import lombok.AccessLevel; @@ -57,7 +57,7 @@ public class MLTrainingTaskRequest extends ActionRequest { * input data set */ @ToString.Exclude - DataFrame dataFrame; + MLInputDataset inputDataset; /** * version id, in case there is future schema change. This can be used to detect which version the client is using. @@ -65,10 +65,10 @@ public class MLTrainingTaskRequest extends ActionRequest { int version; @Builder - public MLTrainingTaskRequest(String algorithm, List parameters, DataFrame dataFrame) { + public MLTrainingTaskRequest(String algorithm, List parameters, MLInputDataset inputDataset) { this.algorithm = algorithm; this.parameters = parameters; - this.dataFrame = dataFrame; + this.inputDataset = inputDataset; this.version = 1; } @@ -78,7 +78,7 @@ public MLTrainingTaskRequest(StreamInput in) throws IOException { this.algorithm = in.readString(); this.parameters = in.readList(MLParameter::new); - this.dataFrame = DataFrameBuilder.load(in); + this.inputDataset = new MLInputDatasetReader().read(in); } @Override @@ -87,8 +87,8 @@ public ActionRequestValidationException validate() { if(Strings.isNullOrEmpty(this.algorithm)) { exception = addValidationError("algorithm name can't be null or empty", exception); } - if(Objects.isNull(this.dataFrame) || this.dataFrame.size() < 1) { - exception = addValidationError("input data can't be null or empty", exception); + if(Objects.isNull(this.inputDataset)) { + exception = addValidationError("input data can't be null", exception); } return exception; @@ -101,7 +101,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(this.algorithm); out.writeList(this.parameters); - this.dataFrame.writeTo(out); + this.inputDataset.writeTo(out); } public static MLTrainingTaskRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java new file mode 100644 index 0000000000..522771ce51 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; + +import static org.junit.Assert.assertEquals; + +public class DataFrameInputDatasetTest { + + @Test + public void writeTo_Success() throws IOException { + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() + .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + put("key1", 2.0D); + }}))) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + dataFrameInputDataset.writeTo(bytesStreamOutput); + assertEquals(21, bytesStreamOutput.size()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/MLInputDatasetReaderTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/MLInputDatasetReaderTest.java new file mode 100644 index 0000000000..c8c4adbb59 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/MLInputDatasetReaderTest.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; + +import static org.junit.Assert.assertEquals; + +public class MLInputDatasetReaderTest { + + MLInputDatasetReader mlInputDatasetReader = new MLInputDatasetReader(); + + @Test + public void read_Success_DataFrameInputDataset() throws IOException { + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() + .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + put("key1", 2.0D); + }}))) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + dataFrameInputDataset.writeTo(bytesStreamOutput); + MLInputDataset inputDataset = mlInputDatasetReader.read(bytesStreamOutput.bytes().streamInput()); + assertEquals(MLInputDataType.DATA_FRAME, inputDataset.getInputDataType()); + assertEquals(1, ((DataFrameInputDataset) inputDataset).getDataFrame().size()); + } + + @Test + public void read_Success_SearchQueryInputDataset() throws IOException { + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() + .indices(Arrays.asList("index1")) + .searchSourceBuilder(new SearchSourceBuilder().size(1)) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + searchQueryInputDataset.writeTo(bytesStreamOutput); + MLInputDataset inputDataset = mlInputDatasetReader.read(bytesStreamOutput.bytes().streamInput()); + assertEquals(MLInputDataType.SEARCH_QUERY, inputDataset.getInputDataType()); + searchQueryInputDataset = (SearchQueryInputDataset) inputDataset; + assertEquals(Arrays.asList("index1"), searchQueryInputDataset.getIndices()); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java new file mode 100644 index 0000000000..bd21f8fcdf --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.common.dataset; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.search.builder.SearchSourceBuilder; + +import static org.junit.Assert.assertEquals; + +public class SearchQueryInputDatasetTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void writeTo_Success() throws IOException { + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() + .indices(Arrays.asList("index1")) + .searchSourceBuilder(new SearchSourceBuilder().size(1)) + .build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + searchQueryInputDataset.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLInputDataType inputDataType = streamInput.readEnum(MLInputDataType.class); + assertEquals(MLInputDataType.SEARCH_QUERY, inputDataType); + searchQueryInputDataset = new SearchQueryInputDataset(streamInput); + assertEquals(1, searchQueryInputDataset.getIndices().size()); + assertEquals(1, searchQueryInputDataset.getSearchSourceBuilder().size()); + } + + @Test + public void init_EmptyIndices() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("indices can't be empty"); + SearchQueryInputDataset.builder() + .indices(new ArrayList<>()) + .searchSourceBuilder(new SearchSourceBuilder().size(1)) + .build(); + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index 4d992fb865..996f1e1c35 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -17,15 +17,13 @@ import java.util.Collections; import java.util.HashMap; -import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.ml.common.dataframe.ColumnMeta; -import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.parameter.MLParameterBuilder; import static org.junit.Assert.assertEquals; @@ -37,16 +35,19 @@ public class MLPredictionTaskRequestTest { @Test public void writeTo_Success() throws IOException { + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - assertEquals(40, bytesStreamOutput.size()); + assertEquals(41, bytesStreamOutput.size()); request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput()); assertEquals("algo", request.getAlgorithm()); assertEquals(1, request.getParameters().size()); @@ -56,12 +57,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); assertNull(request.validate()); } @@ -69,70 +72,59 @@ public void validate_Success() { @Test public void validate_Exception_NullAlgorithmName() { MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm(null) - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm(null) + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: algorithm name can't be null or empty;", exception.getMessage()); } - @Test - public void validate_Exception_EmptyDataFrame() { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) - .dataFrame(DataFrameBuilder.emptyDataFrame(new ColumnMeta[]{ - ColumnMeta.builder() - .name("name") - .columnType(ColumnType.DOUBLE) - .build() - })) - .build(); - - ActionRequestValidationException exception = request.validate(); - - assertEquals("Validation Failed: 1: input data can't be null or empty;", exception.getMessage()); - } - @Test public void validate_Exception_NullDataFrame() { MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) - .dataFrame(null) - .build(); + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(null) + .build(); ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: input data can't be null or empty;", exception.getMessage()); + assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } @Test public void fromActionRequest_Success_WithMLPredictionTaskRequest() { MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + + .build(); assertSame(MLPredictionTaskRequest.fromActionRequest(request), request); } @Test public void fromActionRequest_Success_WithNonMLPredictionTaskRequest() { MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -147,7 +139,7 @@ public void writeTo(StreamOutput out) throws IOException { MLPredictionTaskRequest result = MLPredictionTaskRequest.fromActionRequest(actionRequest); assertNotSame(result, request); assertEquals(request.getAlgorithm(), result.getAlgorithm()); - assertEquals(request.getDataFrame().size(), result.getDataFrame().size()); + assertEquals(request.getInputDataset().getInputDataType(), result.getInputDataset().getInputDataType()); } @Test(expected = UncheckedIOException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java index 435e6708d9..fea4a1ca0f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java @@ -22,9 +22,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.ml.common.dataframe.ColumnMeta; -import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.MLParameterBuilder; import static org.junit.Assert.assertEquals; @@ -37,95 +37,88 @@ public class MLTrainingTaskRequestTest { @Test public void validate_Success() { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); assertNull(request.validate()); } @Test public void validate_Exception_NullAlgoName() { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm(null) - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm(null) + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: algorithm name can't be null or empty;", exception.getMessage()); } - - @Test - public void validate_Exception_EmptyDataFrame() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) - .dataFrame(DataFrameBuilder.emptyDataFrame(new ColumnMeta[]{ - ColumnMeta.builder() - .name("name") - .columnType(ColumnType.DOUBLE) - .build() - })) - .build(); - ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: input data can't be null or empty;", exception.getMessage()); - } - @Test public void validate_Exception_NullDataFrame() { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) - .dataFrame(null) - .build(); + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(null) + .build(); ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: input data can't be null or empty;", exception.getMessage()); + assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } @Test public void writeTo() throws IOException { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - assertEquals(39, bytesStreamOutput.size()); + assertEquals(40, bytesStreamOutput.size()); request = new MLTrainingTaskRequest(bytesStreamOutput.bytes().streamInput()); assertEquals("algo", request.getAlgorithm()); assertEquals(1, request.getParameters().size()); - assertEquals(1, request.getDataFrame().size()); + assertEquals(MLInputDataType.DATA_FRAME, request.getInputDataset().getInputDataType()); } @Test public void fromActionRequest_WithMLTrainingTaskRequest() { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); assertSame(request, MLTrainingTaskRequest.fromActionRequest(request)); } @Test public void fromActionRequest_WithNonMLTrainingTaskRequest() { MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .algorithm("algo") - .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .algorithm("algo") + .parameters(Collections.singletonList(MLParameterBuilder.parameter("k1", 1))) + .inputDataset(DataFrameInputDataset.builder() .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ put("key1", 2.0D); }}))) - .build(); + .build()) + .build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -140,8 +133,8 @@ public void writeTo(StreamOutput out) throws IOException { MLTrainingTaskRequest result = MLTrainingTaskRequest.fromActionRequest(actionRequest); assertNotSame(request, result); assertEquals(request.getAlgorithm(), result.getAlgorithm()); - assertEquals(request.getParameters().size(), request.getParameters().size()); - assertEquals(request.getDataFrame().size(), request.getDataFrame().size()); + assertEquals(request.getParameters().size(), result.getParameters().size()); + assertEquals(request.getInputDataset().getInputDataType(), result.getInputDataset().getInputDataType()); } @Test(expected = UncheckedIOException.class)