Skip to content

Add query based input into interfaces #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;

/**
Expand All @@ -29,7 +32,7 @@ public interface MachineLearningClient {
/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<DataFrame> predict(String algorithm, DataFrame inputData) {
Expand All @@ -40,7 +43,7 @@ default ActionFuture<DataFrame> predict(String algorithm, DataFrame inputData) {
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> parameters, DataFrame inputData) {
Expand All @@ -51,7 +54,7 @@ default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> para
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @param modelId the trained model id
* @return the result future
*/
Expand All @@ -64,7 +67,7 @@ default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> para
/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param inputData input data set
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, DataFrame inputData, ActionListener<DataFrame> listener) {
Expand All @@ -75,13 +78,25 @@ default void predict(String algorithm, DataFrame inputData, ActionListener<DataF
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<DataFrame> listener){
predict(algorithm, parameters, inputData, null, listener);
}

/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data frame
* @param modelId the trained model id
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId, ActionListener<DataFrame> listener) {
Copy link
Collaborator

@ylwu-amzn ylwu-amzn May 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a supported algorithm list? How about add the list or add the doc link in java doc , so the client can know the exact algorithm name they should use?
Another option is to add enum for support algorithm. So user can just use code like this SupportedMLAlgorithm.KMEANS easily without worrying about wrong algorithm name (kMeans, kmean, or k_means?), but maybe the list will be too long.

Copy link
Contributor Author

@weicongs-amazon weicongs-amazon May 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The built-in algorithms candidates are available in the first release For these, we plan to have official page to share the details of each algorithm. We can add that link into the java doc. One issue is created to track this change. #26

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the algorithm enum, we can have an enum for built-in algorithms, but for interface, String will be still used to support custom algorithms built by customers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it's great that we will support custom algorithms. Then it make sense to use string here. Maybe another interface with enum as input is more friendly for users who use built-in algorithms.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. Created issue to track this change. #28 Personally I prefer not providing enum interface, since it will increase some cognitive load to customers. Adding this enum in the javadoc of the interfaces should be good enough. But still open to change this.

predict(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), modelId, listener);
}

/**
* Do prediction machine learning job
* @param algorithm algorithm name
Expand All @@ -90,13 +105,13 @@ default void predict(String algorithm, List<MLParameter> parameters, DataFrame i
* @param modelId the trained model id
* @param listener a listener to be notified of the result
*/
void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId, ActionListener<DataFrame> listener);
void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId, ActionListener<DataFrame> listener);

/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<String> train(String algorithm, List<MLParameter> parameters, DataFrame inputData) {
Expand All @@ -105,13 +120,25 @@ default ActionFuture<String> train(String algorithm, List<MLParameter> parameter
return actionFuture;
}

/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener) {
train(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), listener);
}


/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param listener a listener to be notified of the result
*/
void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener);
void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, ActionListener<String> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.common.Strings;
import org.opensearch.ml.common.dataframe.DataFrame;

import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
Expand All @@ -38,20 +39,20 @@ public class MachineLearningNodeClient implements MachineLearningClient {
NodeClient client;

@Override
public void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId,
public void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId,
ActionListener<DataFrame> listener) {
if(Strings.isNullOrEmpty(algorithm)) {
throw new IllegalArgumentException("algorithm name can't be null or empty");
}
if(Objects.isNull(inputData) || inputData.size() <= 0) {
throw new IllegalArgumentException("input data frame can't be null or empty");
if(Objects.isNull(inputData)) {
throw new IllegalArgumentException("input data set can't be null");
}

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
.algorithm(algorithm)
.modelId(modelId)
.parameters(parameters)
.dataFrame(inputData)
.inputDataset(inputData)
.build();

client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> {
Expand All @@ -64,17 +65,17 @@ public void predict(String algorithm, List<MLParameter> parameters, DataFrame in
}

@Override
public void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener) {
public void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, ActionListener<String> listener) {
if(Strings.isNullOrEmpty(algorithm)) {
throw new IllegalArgumentException("algorithm name can't be null or empty");
}
if(Objects.isNull(inputData) || inputData.size() <= 0) {
throw new IllegalArgumentException("input data frame can't be null or empty");
if(Objects.isNull(inputData)) {
throw new IllegalArgumentException("input data set can't be null");
}

MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
.algorithm(algorithm)
.dataFrame(inputData)
.inputDataset(inputData)
.parameters(parameters)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;

import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -49,13 +50,13 @@ public void setUp() throws Exception {

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId,
public void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId,
ActionListener<DataFrame> listener) {
listener.onResponse(output);
}

@Override
public void train(String algorithm, List<MLParameter> parameters, DataFrame inputData,
public void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData,
ActionListener<String> listener) {
listener.onResponse("taskId");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse;
Expand All @@ -44,7 +45,7 @@ public class MachineLearningNodeClientTest {
NodeClient client;

@Mock
DataFrame input;
MLInputDataset input;

@Mock
DataFrame output;
Expand All @@ -64,7 +65,6 @@ public class MachineLearningNodeClientTest {
@Before
public void setUp() throws Exception {
MockitoAnnotations.openMocks(this);
when(input.size()).thenReturn(1);
}

@Test
Expand Down Expand Up @@ -96,18 +96,10 @@ public void predict_Exception_WithNullAlgorithm() {
}

@Test
public void predict_Exception_WithEmptyDataFrame() {
public void predict_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
when(input.size()).thenReturn(0);
machineLearningNodeClient.predict("algo", null, input, null, dataFrameActionListener);
}

@Test
public void predict_Exception_WithNullDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
machineLearningNodeClient.predict("algo", null, null, null, dataFrameActionListener);
exceptionRule.expectMessage("input data set can't be null");
machineLearningNodeClient.predict("algo", null, (MLInputDataset) null, null, dataFrameActionListener);
}

@Test
Expand Down Expand Up @@ -139,17 +131,9 @@ public void train_Exception_WithNullAlgorithm() {
}

@Test
public void train_Exception_WithEmptyDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
when(input.size()).thenReturn(0);
machineLearningNodeClient.train("algo", null, input, trainingActionListener);
}

@Test
public void train_Exception_WithNullDataFrame() {
public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
machineLearningNodeClient.train("algo", null, null, trainingActionListener);
exceptionRule.expectMessage("input data set can't be null");
machineLearningNodeClient.train("algo", null, (MLInputDataset)null, trainingActionListener);
}
}
5 changes: 3 additions & 2 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ plugins {
dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}-SNAPSHOT"
testCompile group: 'junit', name: 'junit', version: '4.12'

}

jacocoTestReport {
Expand All @@ -46,4 +45,6 @@ jacocoTestCoverageVerification {
}
dependsOn jacocoTestReport
}
check.dependsOn jacocoTestCoverageVerification
check.dependsOn jacocoTestCoverageVerification

lombok.config['lombok.nonNull.exceptionType'] = 'JDK'
1 change: 1 addition & 0 deletions common/lombok.config
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This file is generated by the 'io.freefair.lombok' Gradle plugin
config.stopBubbling = true
lombok.addLombokGeneratedAnnotation = true
lombok.nonNull.exceptionType = JDK
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

import java.io.IOException;

import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.dataframe.DataFrame;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.FieldDefaults;

/**
* DataFrame based input data. Client directly passes the data frame to ML plugin with this.
*/
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class DataFrameInputDataset extends MLInputDataset {
DataFrame dataFrame;

@Builder
public DataFrameInputDataset(@NonNull DataFrame dataFrame) {
super(MLInputDataType.DATA_FRAME);
this.dataFrame = dataFrame;
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
dataFrame.writeTo(streamOutput);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

import java.io.IOException;

import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
public abstract class MLInputDataset implements Writeable {
MLInputDataType inputDataType;

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeEnum(this.inputDataType);
}
}
Loading