Skip to content

Commit

Permalink
[Java] Add order_by support to Java tracking client (mlflow#1482)
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav authored Jun 20, 2019
1 parent 2a7bcb7 commit dbd4944
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,30 @@ public List<RunInfo> searchRuns(List<String> experimentIds, String searchFilter)
public List<RunInfo> searchRuns(List<String> experimentIds,
String searchFilter,
ViewType runViewType) {
SearchRuns.Builder builder = SearchRuns.newBuilder().addAllExperimentIds(experimentIds);
return searchRuns(experimentIds, searchFilter, runViewType, new ArrayList<>());
}


/**
* Return runs from provided list of experiments that satisfy the search query.
*
* @param experimentIds List of experiment IDs.
* @param searchFilter SQL compatible search query string. Format of this query string is
* similar to that specified on MLflow UI.
* Example : "params.model = 'LogisticRegression' and metrics.acc != 0.9"
* @param runViewType ViewType for expected runs. One of (ACTIVE_ONLY, DELETED_ONLY, ALL)
* Defaults to ACTIVE_ONLY.
* @param orderBy List of properties to order by. Example: "metrics.acc DESC".
*
* @return A list of all RunInfos that satisfy search filter.
*/
public List<RunInfo> searchRuns(List<String> experimentIds,
String searchFilter,
ViewType runViewType,
List<String> orderBy) {
SearchRuns.Builder builder = SearchRuns.newBuilder()
.addAllExperimentIds(experimentIds)
.addAllOrderBy(orderBy);

if (searchFilter != null) {
builder.setFilter(searchFilter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Vector;
import java.util.LinkedList;

import com.google.common.collect.Lists;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -250,6 +251,16 @@ public void searchRuns() {

searchResult = client.searchRuns(experimentIds, "tag.test = 'also works'");
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);

searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY,
Lists.newArrayList("metrics.accuracy_score"));
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_2);

searchResult = client.searchRuns(experimentIds, "", ViewType.ACTIVE_ONLY,
Lists.newArrayList("params.min_samples_leaf", "metrics.accuracy_score DESC"));
Assert.assertEquals(searchResult.get(1).getRunUuid(), runId_1);
Assert.assertEquals(searchResult.get(0).getRunUuid(), runId_2);
}

@Test
Expand Down

0 comments on commit dbd4944

Please sign in to comment.