-
Notifications
You must be signed in to change notification settings - Fork 143
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add integration test for ml plugin. (#59)
Signed-off-by: Alex Sun <pengsun@dev-dsk-pengsun-2c-c6fbcf50.us-west-2.amazon.com> Co-authored-by: Alex Sun <pengsun@dev-dsk-pengsun-2c-c6fbcf50.us-west-2.amazon.com>
- Loading branch information
Showing
4 changed files
with
498 additions
and
0 deletions.
There are no files selected for viewing
104 changes: 104 additions & 0 deletions
104
plugin/src/test/java/org/opensearch/ml/action/prediction/PredictionIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package org.opensearch.ml.action.prediction; | ||
|
||
import static org.opensearch.ml.utils.IntegTestUtils.DATA_FRAME_INPUT_DATASET; | ||
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; | ||
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_INDEX_NAME; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateEmptyDataset; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder; | ||
import static org.opensearch.ml.utils.IntegTestUtils.predictAndVerifyResult; | ||
import static org.opensearch.ml.utils.IntegTestUtils.trainModel; | ||
import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; | ||
import static org.opensearch.ml.utils.IntegTestUtils.waitModelAvailable; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.Collections; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.ResourceNotFoundException; | ||
import org.opensearch.action.ActionFuture; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.common.io.stream.NotSerializableExceptionWrapper; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
import org.opensearch.ml.common.dataset.SearchQueryInputDataset; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; | ||
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; | ||
import org.opensearch.ml.plugin.MachineLearningPlugin; | ||
import org.opensearch.plugins.Plugin; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.test.OpenSearchIntegTestCase; | ||
|
||
@OpenSearchIntegTestCase.ClusterScope(transportClientRatio = 0.9) | ||
public class PredictionIT extends OpenSearchIntegTestCase { | ||
private String taskId; | ||
|
||
@Before | ||
public void initTestingData() throws ExecutionException, InterruptedException { | ||
generateMLTestingData(); | ||
|
||
SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); | ||
MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); | ||
taskId = trainModel(inputDataset); | ||
waitModelAvailable(taskId); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> nodePlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> transportClientPlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
public void testTestingData() throws ExecutionException, InterruptedException { | ||
verifyGeneratedTestingData(TESTING_DATA); | ||
waitModelAvailable(taskId); | ||
} | ||
|
||
public void testPredictionWithSearchInput() throws IOException { | ||
SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); | ||
MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); | ||
|
||
predictAndVerifyResult(taskId, inputDataset); | ||
} | ||
|
||
public void testPredictionWithDataInput() throws IOException { | ||
predictAndVerifyResult(taskId, DATA_FRAME_INPUT_DATASET); | ||
} | ||
|
||
public void testPredictionWithoutAlgorithm() throws IOException { | ||
MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest("", new ArrayList<>(), taskId, DATA_FRAME_INPUT_DATASET); | ||
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); | ||
expectThrows(ActionRequestValidationException.class, () -> predictionFuture.actionGet()); | ||
} | ||
|
||
public void testPredictionWithoutModelId() throws IOException { | ||
MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest("kmeans", new ArrayList<>(), "", DATA_FRAME_INPUT_DATASET); | ||
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); | ||
expectThrows(ResourceNotFoundException.class, () -> predictionFuture.actionGet()); | ||
} | ||
|
||
public void testPredictionWithoutDataset() throws IOException { | ||
MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest("kmeans", new ArrayList<>(), taskId, null); | ||
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); | ||
expectThrows(ActionRequestValidationException.class, () -> predictionFuture.actionGet()); | ||
} | ||
|
||
public void testPredictionWithEmptyDataset() throws IOException { | ||
MLInputDataset emptySearchInputDataset = generateEmptyDataset(); | ||
MLPredictionTaskRequest predictionRequest = new MLPredictionTaskRequest( | ||
"kmeans", | ||
new ArrayList<>(), | ||
taskId, | ||
emptySearchInputDataset | ||
); | ||
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest); | ||
expectThrows(NotSerializableExceptionWrapper.class, () -> predictionFuture.actionGet()); | ||
} | ||
} |
71 changes: 71 additions & 0 deletions
71
plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.action.stats; | ||
|
||
import static org.opensearch.ml.action.stats.MLStatsNodesRequest.ALL_STATS_KEY; | ||
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; | ||
import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; | ||
|
||
import java.util.Collection; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.action.ActionFuture; | ||
import org.opensearch.ml.plugin.MachineLearningPlugin; | ||
import org.opensearch.plugins.Plugin; | ||
import org.opensearch.test.OpenSearchIntegTestCase; | ||
|
||
@OpenSearchIntegTestCase.ClusterScope(transportClientRatio = 0.9) | ||
public class MLStatsNodeIT extends OpenSearchIntegTestCase { | ||
@Before | ||
public void initTestingData() throws ExecutionException, InterruptedException { | ||
generateMLTestingData(); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> nodePlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> transportClientPlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
public void testGeneratedTestingData() throws ExecutionException, InterruptedException { | ||
verifyGeneratedTestingData(TESTING_DATA); | ||
} | ||
|
||
public void testNormalCase() throws ExecutionException, InterruptedException { | ||
MLStatsNodesRequest request = new MLStatsNodesRequest(new String[0]); | ||
request.addStat(ALL_STATS_KEY); | ||
|
||
ActionFuture<MLStatsNodesResponse> future = client().execute(MLStatsNodesAction.INSTANCE, request); | ||
MLStatsNodesResponse response = future.get(); | ||
assertNotNull(response); | ||
|
||
List<MLStatsNodeResponse> responseList = response.getNodes(); | ||
assertNotNull(responseList); | ||
assertEquals(1, responseList.size()); | ||
|
||
MLStatsNodeResponse nodeResponse = responseList.get(0); | ||
Map<String, Object> statsMap = nodeResponse.getStatsMap(); | ||
|
||
assertNotNull(statsMap); | ||
assertEquals(0, statsMap.size()); | ||
} | ||
} |
142 changes: 142 additions & 0 deletions
142
plugin/src/test/java/org/opensearch/ml/action/training/TrainingIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
* | ||
* Modifications Copyright OpenSearch Contributors. See | ||
* GitHub history for details. | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.action.training; | ||
|
||
import static org.opensearch.ml.indices.MLIndicesHandler.OS_ML_MODEL_RESULT; | ||
import static org.opensearch.ml.utils.IntegTestUtils.DATA_FRAME_INPUT_DATASET; | ||
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_DATA; | ||
import static org.opensearch.ml.utils.IntegTestUtils.TESTING_INDEX_NAME; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateMLTestingData; | ||
import static org.opensearch.ml.utils.IntegTestUtils.generateSearchSourceBuilder; | ||
import static org.opensearch.ml.utils.IntegTestUtils.trainModel; | ||
import static org.opensearch.ml.utils.IntegTestUtils.verifyGeneratedTestingData; | ||
import static org.opensearch.ml.utils.IntegTestUtils.waitModelAvailable; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Collection; | ||
import java.util.Collections; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.junit.Before; | ||
import org.opensearch.action.ActionFuture; | ||
import org.opensearch.action.ActionRequestValidationException; | ||
import org.opensearch.action.search.SearchAction; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.index.query.QueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
import org.opensearch.ml.common.dataset.SearchQueryInputDataset; | ||
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; | ||
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; | ||
import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse; | ||
import org.opensearch.ml.plugin.MachineLearningPlugin; | ||
import org.opensearch.plugins.Plugin; | ||
import org.opensearch.search.builder.SearchSourceBuilder; | ||
import org.opensearch.test.OpenSearchIntegTestCase; | ||
|
||
@OpenSearchIntegTestCase.ClusterScope(transportClientRatio = 0.9) | ||
public class TrainingIT extends OpenSearchIntegTestCase { | ||
@Before | ||
public void initTestingData() throws ExecutionException, InterruptedException { | ||
generateMLTestingData(); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> nodePlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
@Override | ||
protected Collection<Class<? extends Plugin>> transportClientPlugins() { | ||
return Collections.singletonList(MachineLearningPlugin.class); | ||
} | ||
|
||
public void testGeneratedTestingData() throws ExecutionException, InterruptedException { | ||
verifyGeneratedTestingData(TESTING_DATA); | ||
} | ||
|
||
public void testTrainingWithSearchInput() throws ExecutionException, InterruptedException, IOException { | ||
SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); | ||
MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); | ||
|
||
String taskId = trainModel(inputDataset); | ||
|
||
waitModelAvailable(taskId); | ||
} | ||
|
||
public void testTrainingWithDataInput() throws ExecutionException, InterruptedException, IOException { | ||
String taskId = trainModel(DATA_FRAME_INPUT_DATASET); | ||
|
||
waitModelAvailable(taskId); | ||
} | ||
|
||
// Train a model without algorithm. | ||
public void testTrainingWithoutAlgorithm() { | ||
SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); | ||
MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); | ||
MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest("", new ArrayList<>(), inputDataset); | ||
expectThrows(ActionRequestValidationException.class, () -> { | ||
ActionFuture<MLTrainingTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); | ||
trainingFuture.actionGet(); | ||
}); | ||
} | ||
|
||
// Train a model without dataset. | ||
public void testTrainingWithoutDataset() { | ||
MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest("kmeans", new ArrayList<>(), null); | ||
expectThrows(ActionRequestValidationException.class, () -> { | ||
ActionFuture<MLTrainingTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); | ||
trainingFuture.actionGet(); | ||
}); | ||
} | ||
|
||
// Train a model with empty dataset. | ||
public void testTrainingWithEmptyDataset() throws InterruptedException { | ||
SearchSourceBuilder searchSourceBuilder = generateSearchSourceBuilder(); | ||
searchSourceBuilder.query(QueryBuilders.matchQuery("noSuchName", "")); | ||
MLInputDataset inputDataset = new SearchQueryInputDataset(Collections.singletonList(TESTING_INDEX_NAME), searchSourceBuilder); | ||
MLTrainingTaskRequest trainingRequest = new MLTrainingTaskRequest("kmeans", new ArrayList<>(), inputDataset); | ||
|
||
ActionFuture<MLTrainingTaskResponse> trainingFuture = client().execute(MLTrainingTaskAction.INSTANCE, trainingRequest); | ||
MLTrainingTaskResponse trainingResponse = trainingFuture.actionGet(); | ||
|
||
// The training taskId and status will be response to the client. | ||
assertNotNull(trainingResponse); | ||
String taskId = trainingResponse.getTaskId(); | ||
String status = trainingResponse.getStatus(); | ||
assertNotNull(taskId); | ||
assertFalse(taskId.isEmpty()); | ||
assertEquals("CREATED", status); | ||
|
||
SearchSourceBuilder modelSearchSourceBuilder = new SearchSourceBuilder(); | ||
QueryBuilder queryBuilder = QueryBuilders.termQuery("taskId", taskId); | ||
modelSearchSourceBuilder.query(queryBuilder); | ||
SearchRequest modelSearchRequest = new SearchRequest(new String[] { OS_ML_MODEL_RESULT }, modelSearchSourceBuilder); | ||
SearchResponse modelSearchResponse = null; | ||
int i = 0; | ||
while ((modelSearchResponse == null || modelSearchResponse.getHits().getTotalHits().value == 0) && i < 100) { | ||
try { | ||
ActionFuture<SearchResponse> searchFuture = client().execute(SearchAction.INSTANCE, modelSearchRequest); | ||
modelSearchResponse = searchFuture.actionGet(); | ||
} catch (Exception e) {} finally { | ||
// Wait 100 ms until get valid search response or timeout. | ||
Thread.sleep(100); | ||
} | ||
i++; | ||
} | ||
// No model would be trained successfully with empty dataset. | ||
assertNull(modelSearchResponse); | ||
} | ||
} |
Oops, something went wrong.