Skip to content

Commit

Permalink
Add query based input into interfaces (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
weicongs-amazon authored May 7, 2021
1 parent 78f3792 commit e682a4f
Show file tree
Hide file tree
Showing 18 changed files with 490 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;

/**
Expand All @@ -29,7 +32,7 @@ public interface MachineLearningClient {
/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<DataFrame> predict(String algorithm, DataFrame inputData) {
Expand All @@ -40,7 +43,7 @@ default ActionFuture<DataFrame> predict(String algorithm, DataFrame inputData) {
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> parameters, DataFrame inputData) {
Expand All @@ -51,7 +54,7 @@ default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> para
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @param modelId the trained model id
* @return the result future
*/
Expand All @@ -64,7 +67,7 @@ default ActionFuture<DataFrame> predict(String algorithm, List<MLParameter> para
/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param inputData input data set
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, DataFrame inputData, ActionListener<DataFrame> listener) {
Expand All @@ -75,13 +78,25 @@ default void predict(String algorithm, DataFrame inputData, ActionListener<DataF
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<DataFrame> listener){
predict(algorithm, parameters, inputData, null, listener);
}

/**
* Do prediction machine learning job
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data frame
* @param modelId the trained model id
* @param listener a listener to be notified of the result
*/
default void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId, ActionListener<DataFrame> listener) {
predict(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), modelId, listener);
}

/**
* Do prediction machine learning job
* @param algorithm algorithm name
Expand All @@ -90,13 +105,13 @@ default void predict(String algorithm, List<MLParameter> parameters, DataFrame i
* @param modelId the trained model id
* @param listener a listener to be notified of the result
*/
void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId, ActionListener<DataFrame> listener);
void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId, ActionListener<DataFrame> listener);

/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param inputData input data frame
* @return the result future
*/
default ActionFuture<String> train(String algorithm, List<MLParameter> parameters, DataFrame inputData) {
Expand All @@ -105,13 +120,25 @@ default ActionFuture<String> train(String algorithm, List<MLParameter> parameter
return actionFuture;
}

/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data frame
* @param listener a listener to be notified of the result
*/
default void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener) {
train(algorithm, parameters, DataFrameInputDataset.builder().dataFrame(inputData).build(), listener);
}


/**
* Do the training machine learning job. The training job will be always async process. The job id will be returned in this method.
* @param algorithm algorithm name
* @param parameters parameters of ml algorithm
* @param inputData input data set
* @param listener a listener to be notified of the result
*/
void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener);
void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, ActionListener<String> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.common.Strings;
import org.opensearch.ml.common.dataframe.DataFrame;

import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
Expand All @@ -38,20 +39,20 @@ public class MachineLearningNodeClient implements MachineLearningClient {
NodeClient client;

@Override
public void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId,
public void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId,
ActionListener<DataFrame> listener) {
if(Strings.isNullOrEmpty(algorithm)) {
throw new IllegalArgumentException("algorithm name can't be null or empty");
}
if(Objects.isNull(inputData) || inputData.size() <= 0) {
throw new IllegalArgumentException("input data frame can't be null or empty");
if(Objects.isNull(inputData)) {
throw new IllegalArgumentException("input data set can't be null");
}

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
.algorithm(algorithm)
.modelId(modelId)
.parameters(parameters)
.dataFrame(inputData)
.inputDataset(inputData)
.build();

client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, ActionListener.wrap(response -> {
Expand All @@ -64,17 +65,17 @@ public void predict(String algorithm, List<MLParameter> parameters, DataFrame in
}

@Override
public void train(String algorithm, List<MLParameter> parameters, DataFrame inputData, ActionListener<String> listener) {
public void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, ActionListener<String> listener) {
if(Strings.isNullOrEmpty(algorithm)) {
throw new IllegalArgumentException("algorithm name can't be null or empty");
}
if(Objects.isNull(inputData) || inputData.size() <= 0) {
throw new IllegalArgumentException("input data frame can't be null or empty");
if(Objects.isNull(inputData)) {
throw new IllegalArgumentException("input data set can't be null");
}

MLTrainingTaskRequest trainingTaskRequest = MLTrainingTaskRequest.builder()
.algorithm(algorithm)
.dataFrame(inputData)
.inputDataset(inputData)
.parameters(parameters)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.MLParameter;

import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -49,13 +50,13 @@ public void setUp() throws Exception {

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String algorithm, List<MLParameter> parameters, DataFrame inputData, String modelId,
public void predict(String algorithm, List<MLParameter> parameters, MLInputDataset inputData, String modelId,
ActionListener<DataFrame> listener) {
listener.onResponse(output);
}

@Override
public void train(String algorithm, List<MLParameter> parameters, DataFrame inputData,
public void train(String algorithm, List<MLParameter> parameters, MLInputDataset inputData,
ActionListener<String> listener) {
listener.onResponse("taskId");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse;
Expand All @@ -44,7 +45,7 @@ public class MachineLearningNodeClientTest {
NodeClient client;

@Mock
DataFrame input;
MLInputDataset input;

@Mock
DataFrame output;
Expand All @@ -64,7 +65,6 @@ public class MachineLearningNodeClientTest {
@Before
public void setUp() throws Exception {
MockitoAnnotations.openMocks(this);
when(input.size()).thenReturn(1);
}

@Test
Expand Down Expand Up @@ -96,18 +96,10 @@ public void predict_Exception_WithNullAlgorithm() {
}

@Test
public void predict_Exception_WithEmptyDataFrame() {
public void predict_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
when(input.size()).thenReturn(0);
machineLearningNodeClient.predict("algo", null, input, null, dataFrameActionListener);
}

@Test
public void predict_Exception_WithNullDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
machineLearningNodeClient.predict("algo", null, null, null, dataFrameActionListener);
exceptionRule.expectMessage("input data set can't be null");
machineLearningNodeClient.predict("algo", null, (MLInputDataset) null, null, dataFrameActionListener);
}

@Test
Expand Down Expand Up @@ -139,17 +131,9 @@ public void train_Exception_WithNullAlgorithm() {
}

@Test
public void train_Exception_WithEmptyDataFrame() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
when(input.size()).thenReturn(0);
machineLearningNodeClient.train("algo", null, input, trainingActionListener);
}

@Test
public void train_Exception_WithNullDataFrame() {
public void train_Exception_WithNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data frame can't be null or empty");
machineLearningNodeClient.train("algo", null, null, trainingActionListener);
exceptionRule.expectMessage("input data set can't be null");
machineLearningNodeClient.train("algo", null, (MLInputDataset)null, trainingActionListener);
}
}
5 changes: 3 additions & 2 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ plugins {
dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}-SNAPSHOT"
testCompile group: 'junit', name: 'junit', version: '4.12'

}

jacocoTestReport {
Expand All @@ -46,4 +45,6 @@ jacocoTestCoverageVerification {
}
dependsOn jacocoTestReport
}
check.dependsOn jacocoTestCoverageVerification
check.dependsOn jacocoTestCoverageVerification

lombok.config['lombok.nonNull.exceptionType'] = 'JDK'
1 change: 1 addition & 0 deletions common/lombok.config
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This file is generated by the 'io.freefair.lombok' Gradle plugin
config.stopBubbling = true
lombok.addLombokGeneratedAnnotation = true
lombok.nonNull.exceptionType = JDK
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

import java.io.IOException;

import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.dataframe.DataFrame;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.experimental.FieldDefaults;

/**
* DataFrame based input data. Client directly passes the data frame to ML plugin with this.
*/
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class DataFrameInputDataset extends MLInputDataset {
DataFrame dataFrame;

@Builder
public DataFrameInputDataset(@NonNull DataFrame dataFrame) {
super(MLInputDataType.DATA_FRAME);
this.dataFrame = dataFrame;
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
dataFrame.writeTo(streamOutput);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.ml.common.dataset;

import java.io.IOException;

import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@RequiredArgsConstructor
public abstract class MLInputDataset implements Writeable {
MLInputDataType inputDataType;

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeEnum(this.inputDataType);
}
}
Loading

0 comments on commit e682a4f

Please sign in to comment.