Skip to content

Commit

Permalink
feat(backend): sort by run metrics - step 2 (#4235)
Browse files Browse the repository at this point in the history
* enable pagination when expanding experiment in both the home page and the archive page

* Revert "enable pagination when expanding experiment in both the home page and the archive page"

This reverts commit 5b67273.

* sorting by run metrics is different from sorting by name, uuid, created at, etc. The lattre are direct field in listable object, the former is an element in an arrary-typed field in listable object. In other words, the latter are columns in table, the former is not.

* unit test: add sorting on metrics with both asc and desc order

* list is generic. model specific test is put to run_store_test.go
  • Loading branch information
jingzhang36 authored Jul 21, 2020
1 parent 9431009 commit d4d3616
Show file tree
Hide file tree
Showing 11 changed files with 438 additions and 37 deletions.
2 changes: 2 additions & 0 deletions backend/src/apiserver/list/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ go_library(
"//backend/api:go_default_library",
"//backend/src/apiserver/common:go_default_library",
"//backend/src/apiserver/filter:go_default_library",
"//backend/src/apiserver/model:go_default_library",
"//backend/src/common/util:go_default_library",
"@com_github_golang_glog//:go_default_library",
"@com_github_masterminds_squirrel//:go_default_library",
],
)
Expand Down
107 changes: 79 additions & 28 deletions backend/src/apiserver/list/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,22 @@ type token struct {
// SortByFieldValue is the value of the sorted field of the next row to be
// returned.
SortByFieldValue interface{}
// SortByFieldIsRunMetric indicates whether the SortByFieldName field is
// a run metric field or not.
SortByFieldIsRunMetric bool

// KeyFieldName is the name of the primary key for the model being queried.
KeyFieldName string
// KeyFieldValue is the value of the sorted field of the next row to be
// returned.
KeyFieldValue interface{}

// IsDesc is true if the sorting order should be descending.
IsDesc bool

// ModelName is the table where ***FieldName belongs to.
ModelName string

// Filter represents the filtering that should be applied in the query.
Filter *filter.Filter
}
Expand Down Expand Up @@ -94,7 +101,7 @@ type Options struct {
// Matches returns trues if the sorting and filtering criteria in o matches that
// of the one supplied in opts.
func (o *Options) Matches(opts *Options) bool {
return o.SortByFieldName == opts.SortByFieldName &&
return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric &&
o.IsDesc == opts.IsDesc &&
reflect.DeepEqual(o.Filter, opts.Filter)
}
Expand Down Expand Up @@ -140,13 +147,23 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api
}

token.SortByFieldName = listable.DefaultSortField()
token.SortByFieldIsRunMetric = false
if len(queryList) > 0 {
var err error
n, ok := listable.APIToModelFieldMap()[queryList[0]]
if !ok {
if ok {
token.SortByFieldName = n
} else if strings.HasPrefix(queryList[0], "metric:") {
// Sorting on metrics is only available on certain runs.
model := reflect.ValueOf(listable).Elem().Type().Name()
if model != "Run" {
return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err)
}
token.SortByFieldName = queryList[0][7:]
token.SortByFieldIsRunMetric = true
} else {
return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err)
}
token.SortByFieldName = n
}

if len(queryList) == 2 {
Expand Down Expand Up @@ -176,28 +193,34 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu
return sqlBuilder
}

// AddPaginationToSelect adds WHERE clauses with the sorting and pagination criteria in the
// Options o to the supplied SelectBuilder, and returns the new SelectBuilder
// containing these.
// AddSortingToSelect adds Order By clause.
func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder {
// If next row's value is specified, set those values in the clause.
var modelNamePrefix string
// When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed.
var keyFieldPrefix string
var sortByFieldPrefix string
if len(o.ModelName) == 0 {
modelNamePrefix = ""
keyFieldPrefix = ""
sortByFieldPrefix = ""
} else if o.SortByFieldIsRunMetric {
keyFieldPrefix = o.ModelName + "."
sortByFieldPrefix = ""
} else {
modelNamePrefix = o.ModelName + "."
keyFieldPrefix = o.ModelName + "."
sortByFieldPrefix = o.ModelName + "."
}

// If next row's value is specified, set those values in the clause.
if o.SortByFieldValue != nil && o.KeyFieldValue != nil {
if o.IsDesc {
sqlBuilder = sqlBuilder.
Where(sq.Or{sq.Lt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue},
sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue},
sq.LtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}})
Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue},
sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue},
sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}})
} else {
sqlBuilder = sqlBuilder.
Where(sq.Or{sq.Gt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue},
sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue},
sq.GtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}})
Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue},
sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue},
sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}})
}
}

Expand All @@ -206,12 +229,25 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild
order = "DESC"
}
sqlBuilder = sqlBuilder.
OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.SortByFieldName, order)).
OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.KeyFieldName, order))
OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)).
OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order))

return sqlBuilder
}

// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table.
// With the metric as a field in the select clause enable sorting on this metric afterwards.
func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder {
if !o.SortByFieldIsRunMetric {
return sqlBuilder
}
// TODO(jingzhang36): address the case where runs doesn't have the specified metric.
return sq.
Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName).
FromSelect(sqlBuilder, "selected_runs").
LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'")
}

// AddFilterToSelect adds WHERE clauses with the filtering criteria in the
// Options o to the supplied SelectBuilder, and returns the new SelectBuilder
// containing these.
Expand Down Expand Up @@ -309,6 +345,8 @@ type Listable interface {
APIToModelFieldMap() map[string]string
// GetModelName returns table name used as sort field prefix.
GetModelName() string
// Find the value of a given field in a listable object.
GetFieldValue(name string) interface{}
}

// NextPageToken returns a string that can be used to fetch the subsequent set
Expand All @@ -326,9 +364,21 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) {
elem := reflect.ValueOf(listable).Elem()
elemName := elem.Type().Name()

sortByField := elem.FieldByName(o.SortByFieldName)
if !sortByField.IsValid() {
return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName)
var sortByField interface{}
// TODO(jingzhang36): this if-else block can be simplified to one call to
// GetFieldValue after all the models (run, job, experiment, etc.) implement
// GetFieldValue method in listable interface.
if !o.SortByFieldIsRunMetric {
if value := elem.FieldByName(o.SortByFieldName); value.IsValid() {
sortByField = value.Interface()
} else {
return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName)
}
} else {
sortByField = listable.GetFieldValue(o.SortByFieldName)
if sortByField == nil {
return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName)
}
}

keyField := elem.FieldByName(listable.PrimaryKeyColumnName())
Expand All @@ -337,13 +387,14 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) {
}

return &token{
SortByFieldName: o.SortByFieldName,
SortByFieldValue: sortByField.Interface(),
KeyFieldName: listable.PrimaryKeyColumnName(),
KeyFieldValue: keyField.Interface(),
IsDesc: o.IsDesc,
Filter: o.Filter,
ModelName: o.ModelName,
SortByFieldName: o.SortByFieldName,
SortByFieldValue: sortByField,
SortByFieldIsRunMetric: o.SortByFieldIsRunMetric,
KeyFieldName: listable.PrimaryKeyColumnName(),
KeyFieldValue: keyField.Interface(),
IsDesc: o.IsDesc,
Filter: o.Filter,
ModelName: o.ModelName,
}, nil
}

Expand Down
4 changes: 4 additions & 0 deletions backend/src/apiserver/list/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (f *fakeListable) GetModelName() string {
return ""
}

func (f *fakeListable) GetFieldValue(name string) interface{} {
return nil
}

func TestNextPageToken_ValidTokens(t *testing.T) {
l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234}

Expand Down
5 changes: 5 additions & 0 deletions backend/src/apiserver/model/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ func (e *Experiment) APIToModelFieldMap() map[string]string {
func (e *Experiment) GetModelName() string {
return "experiments"
}

func (e *Experiment) GetFieldValue(name string) interface{} {
// TODO(jingzhang36): follow the example of GetFieldValue in run.go
return nil
}
5 changes: 5 additions & 0 deletions backend/src/apiserver/model/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,8 @@ func (k *Job) APIToModelFieldMap() map[string]string {
func (j *Job) GetModelName() string {
return "jobs"
}

func (j *Job) GetFieldValue(name string) interface{} {
// TODO(jingzhang36): follow the example of GetFieldValue in run.go
return nil
}
5 changes: 5 additions & 0 deletions backend/src/apiserver/model/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ func (p *Pipeline) APIToModelFieldMap() map[string]string {
func (p *Pipeline) GetModelName() string {
return "pipelines"
}

func (p *Pipeline) GetFieldValue(name string) interface{} {
// TODO(jingzhang36): follow the example of GetFieldValue in run.go
return nil
}
5 changes: 5 additions & 0 deletions backend/src/apiserver/model/pipeline_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,8 @@ func (p *PipelineVersion) APIToModelFieldMap() map[string]string {
func (p *PipelineVersion) GetModelName() string {
return "pipeline_versions"
}

func (p *PipelineVersion) GetFieldValue(name string) interface{} {
// TODO(jingzhang36): follow the example of GetFieldValue in run.go
return nil
}
29 changes: 29 additions & 0 deletions backend/src/apiserver/model/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,32 @@ func (r *Run) GetModelName() string {
// and thus as prefix in sorting fields.
return ""
}

func (r *Run) GetFieldValue(name string) interface{} {
// "name" could be a field in Run type or a name inside an array typed field
// in Run type
// First, try to find the value if "name" is a field in Run type
switch name {
case "UUID":
return r.UUID
case "DisplayName":
return r.DisplayName
case "CreatedAtInSec":
return r.CreatedAtInSec
case "Description":
return r.Description
case "ScheduledAtInSec":
return r.ScheduledAtInSec
case "StorageState":
return r.StorageState
case "Conditions":
return r.Conditions
}
// Second, try to find the match of "name" inside an array typed field
for _, metric := range r.Metrics {
if metric.Name == name {
return metric.NumberValue
}
}
return nil
}
4 changes: 4 additions & 0 deletions backend/src/apiserver/server/list_request_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ func (f *fakeListable) GetModelName() string {
return ""
}

func (f *fakeListable) GetFieldValue(name string) interface{} {
return nil
}

func TestValidatedListOptions_Errors(t *testing.T) {
opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil)
if err != nil {
Expand Down
31 changes: 26 additions & 5 deletions backend/src/apiserver/storage/run_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options,
// If we're not just counting, then also add select columns and perform a left join
// to get resource reference information. Also add pagination.
if !selectCount {
sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder)
sqlBuilder = opts.AddPaginationToSelect(sqlBuilder)
sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder)
sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts)
sqlBuilder = opts.AddSortingToSelect(sqlBuilder)
}
sql, args, err := sqlBuilder.ToSql()
Expand All @@ -187,7 +188,7 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) {
sq.Select(runColumns...).
From("run_details").
Where(sq.Eq{"UUID": runId}).
Limit(1)).
Limit(1), nil).
ToSql()

if err != nil {
Expand All @@ -213,17 +214,37 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) {
return runs[0], nil
}

func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder) sq.SelectBuilder {
// Apply func f to every string in a given string slice.
func Map(vs []string, f func(string) string) []string {
vsm := make([]string, len(vs))
for i, v := range vs {
vsm[i] = f(v)
}
return vsm
}

func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder {
resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "")
columnsAfterJoiningResourceReferences := append(
Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns
resourceRefConcatQuery+" AS refs")
if opts != nil && opts.SortByFieldIsRunMetric {
columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName)
}
subQ := sq.
Select("rd.*", resourceRefConcatQuery+" AS refs").
Select(columnsAfterJoiningResourceReferences...).
FromSelect(filteredSelectBuilder, "rd").
LeftJoin("resource_references AS rr ON rr.ResourceType='Run' AND rd.UUID=rr.ResourceUUID").
GroupBy("rd.UUID")

// TODO(jingzhang36): address the case where some runs don't have the metric used in order by.
metricConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rm.Payload", ","), `"]"`}, "")
columnsAfterJoiningRunMetrics := append(
Map(runColumns, func(column string) string { return "subq." + column }), // Add prefix "subq." to runColumns
"subq.refs",
metricConcatQuery+" AS metrics")
return sq.
Select("subq.*", metricConcatQuery+" AS metrics").
Select(columnsAfterJoiningRunMetrics...).
FromSelect(subQ, "subq").
LeftJoin("run_metrics AS rm ON subq.UUID=rm.RunUUID").
GroupBy("subq.UUID")
Expand Down
Loading

0 comments on commit d4d3616

Please sign in to comment.