diff --git a/mlflow/java/client/src/main/java/org/mlflow/tracking/MlflowClient.java b/mlflow/java/client/src/main/java/org/mlflow/tracking/MlflowClient.java index da753951be005..d30445dd6d27f 100644 --- a/mlflow/java/client/src/main/java/org/mlflow/tracking/MlflowClient.java +++ b/mlflow/java/client/src/main/java/org/mlflow/tracking/MlflowClient.java @@ -151,7 +151,30 @@ public List searchRuns(List experimentIds, String searchFilter) public List searchRuns(List experimentIds, String searchFilter, ViewType runViewType) { - SearchRuns.Builder builder = SearchRuns.newBuilder().addAllExperimentIds(experimentIds); + return searchRuns(experimentIds, searchFilter, runViewType, new ArrayList<>()); + } + + + /** + * Return runs from provided list of experiments that satisfy the search query. + * + * @param experimentIds List of experiment IDs. + * @param searchFilter SQL compatible search query string. Format of this query string is + * similar to that specified on MLflow UI. + * Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9" + * @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL) + * Defaults to ACTIVE_ONLY. + * @param orderBy List of properties to order by. Example: "metrics.acc DESC". + * + * @return A list of all RunInfos that satisfy search filter. + */ + public List searchRuns(List experimentIds, + String searchFilter, + ViewType runViewType, + List orderBy) { + SearchRuns.Builder builder = SearchRuns.newBuilder() + .addAllExperimentIds(experimentIds) + .addAllOrderBy(orderBy); if (searchFilter != null) { builder.setFilter(searchFilter); diff --git a/mlflow/java/client/src/test/java/org/mlflow/tracking/MlflowClientTest.java b/mlflow/java/client/src/test/java/org/mlflow/tracking/MlflowClientTest.java index 7dbe099f6e4b2..3d73d704a94fd 100644 --- a/mlflow/java/client/src/test/java/org/mlflow/tracking/MlflowClientTest.java +++ b/mlflow/java/client/src/test/java/org/mlflow/tracking/MlflowClientTest.java @@ -13,6 +13,7 @@ import java.util.Vector; import java.util.LinkedList; +import com.google.common.collect.Lists; import org.apache.commons.io.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -250,6 +251,16 @@ public void searchRuns() { searchResult = client.searchRuns(experimentIds, "tag.test = 'also works'"); Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2); + + searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY, + Lists.newArrayList("metrics.accuracy_score")); + Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1); + Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_2); + + searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY, + Lists.newArrayList("params.min_samples_leaf", "metrics.accuracy_score DESC")); + Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_1); + Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2); } @Test