diff --git a/backend/src/apiserver/list/BUILD.bazel b/backend/src/apiserver/list/BUILD.bazel index 3a26548af88..b46aee86303 100644 --- a/backend/src/apiserver/list/BUILD.bazel +++ b/backend/src/apiserver/list/BUILD.bazel @@ -9,7 +9,9 @@ go_library( "//backend/api:go_default_library", "//backend/src/apiserver/common:go_default_library", "//backend/src/apiserver/filter:go_default_library", + "//backend/src/apiserver/model:go_default_library", "//backend/src/common/util:go_default_library", + "@com_github_golang_glog//:go_default_library", "@com_github_masterminds_squirrel//:go_default_library", ], ) diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index 2a44b39b4ec..54698e48142 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -44,15 +44,22 @@ type token struct { // SortByFieldValue is the value of the sorted field of the next row to be // returned. SortByFieldValue interface{} + // SortByFieldIsRunMetric indicates whether the SortByFieldName field is + // a run metric field or not. + SortByFieldIsRunMetric bool + // KeyFieldName is the name of the primary key for the model being queried. KeyFieldName string // KeyFieldValue is the value of the sorted field of the next row to be // returned. KeyFieldValue interface{} + // IsDesc is true if the sorting order should be descending. IsDesc bool + // ModelName is the table where ***FieldName belongs to. ModelName string + // Filter represents the filtering that should be applied in the query. Filter *filter.Filter } @@ -94,7 +101,7 @@ type Options struct { // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { - return o.SortByFieldName == opts.SortByFieldName && + return o.SortByFieldName == opts.SortByFieldName && o.SortByFieldIsRunMetric == opts.SortByFieldIsRunMetric && o.IsDesc == opts.IsDesc && reflect.DeepEqual(o.Filter, opts.Filter) } @@ -140,13 +147,23 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api } token.SortByFieldName = listable.DefaultSortField() + token.SortByFieldIsRunMetric = false if len(queryList) > 0 { var err error n, ok := listable.APIToModelFieldMap()[queryList[0]] - if !ok { + if ok { + token.SortByFieldName = n + } else if strings.HasPrefix(queryList[0], "metric:") { + // Sorting on metrics is only available on certain runs. + model := reflect.ValueOf(listable).Elem().Type().Name() + if model != "Run" { + return nil, util.NewInvalidInputError("Invalid sorting field: %q on %q : %s", queryList[0], model, err) + } + token.SortByFieldName = queryList[0][7:] + token.SortByFieldIsRunMetric = true + } else { return nil, util.NewInvalidInputError("Invalid sorting field: %q: %s", queryList[0], err) } - token.SortByFieldName = n } if len(queryList) == 2 { @@ -176,28 +193,34 @@ func (o *Options) AddPaginationToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBu return sqlBuilder } -// AddPaginationToSelect adds WHERE clauses with the sorting and pagination criteria in the -// Options o to the supplied SelectBuilder, and returns the new SelectBuilder -// containing these. +// AddSortingToSelect adds Order By clause. func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { - // If next row's value is specified, set those values in the clause. - var modelNamePrefix string + // When sorting by a direct field in the listable model (i.e., name in Run or uuid in Pipeline), a sortByFieldPrefix can be specified; when sorting by a field in an array-typed dictionary (i.e., a run metric inside the metrics in Run), a sortByFieldPrefix is not needed. + var keyFieldPrefix string + var sortByFieldPrefix string if len(o.ModelName) == 0 { - modelNamePrefix = "" + keyFieldPrefix = "" + sortByFieldPrefix = "" + } else if o.SortByFieldIsRunMetric { + keyFieldPrefix = o.ModelName + "." + sortByFieldPrefix = "" } else { - modelNamePrefix = o.ModelName + "." + keyFieldPrefix = o.ModelName + "." + sortByFieldPrefix = o.ModelName + "." } + + // If next row's value is specified, set those values in the clause. if o.SortByFieldValue != nil && o.KeyFieldValue != nil { if o.IsDesc { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Lt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.LtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Lt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.LtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } else { sqlBuilder = sqlBuilder. - Where(sq.Or{sq.Gt{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.And{sq.Eq{modelNamePrefix + o.SortByFieldName: o.SortByFieldValue}, - sq.GtOrEq{modelNamePrefix + o.KeyFieldName: o.KeyFieldValue}}}) + Where(sq.Or{sq.Gt{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.And{sq.Eq{sortByFieldPrefix + o.SortByFieldName: o.SortByFieldValue}, + sq.GtOrEq{keyFieldPrefix + o.KeyFieldName: o.KeyFieldValue}}}) } } @@ -206,12 +229,25 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild order = "DESC" } sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", modelNamePrefix+o.KeyFieldName, order)) + OrderBy(fmt.Sprintf("%v %v", sortByFieldPrefix+o.SortByFieldName, order)). + OrderBy(fmt.Sprintf("%v %v", keyFieldPrefix+o.KeyFieldName, order)) return sqlBuilder } +// Add a metric as a new field to the select clause by join the passed-in SQL query with run_metrics table. +// With the metric as a field in the select clause enable sorting on this metric afterwards. +func (o *Options) AddSortByRunMetricToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuilder { + if !o.SortByFieldIsRunMetric { + return sqlBuilder + } + // TODO(jingzhang36): address the case where runs doesn't have the specified metric. + return sq. + Select("selected_runs.*, run_metrics.numbervalue as "+o.SortByFieldName). + FromSelect(sqlBuilder, "selected_runs"). + LeftJoin("run_metrics ON selected_runs.uuid=run_metrics.runuuid AND run_metrics.name='" + o.SortByFieldName + "'") +} + // AddFilterToSelect adds WHERE clauses with the filtering criteria in the // Options o to the supplied SelectBuilder, and returns the new SelectBuilder // containing these. @@ -309,6 +345,8 @@ type Listable interface { APIToModelFieldMap() map[string]string // GetModelName returns table name used as sort field prefix. GetModelName() string + // Find the value of a given field in a listable object. + GetFieldValue(name string) interface{} } // NextPageToken returns a string that can be used to fetch the subsequent set @@ -326,9 +364,21 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { elem := reflect.ValueOf(listable).Elem() elemName := elem.Type().Name() - sortByField := elem.FieldByName(o.SortByFieldName) - if !sortByField.IsValid() { - return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) + var sortByField interface{} + // TODO(jingzhang36): this if-else block can be simplified to one call to + // GetFieldValue after all the models (run, job, experiment, etc.) implement + // GetFieldValue method in listable interface. + if !o.SortByFieldIsRunMetric { + if value := elem.FieldByName(o.SortByFieldName); value.IsValid() { + sortByField = value.Interface() + } else { + return nil, util.NewInvalidInputError("cannot sort by field %q on type %q", o.SortByFieldName, elemName) + } + } else { + sortByField = listable.GetFieldValue(o.SortByFieldName) + if sortByField == nil { + return nil, util.NewInvalidInputError("Unable to find run metric %s", o.SortByFieldName) + } } keyField := elem.FieldByName(listable.PrimaryKeyColumnName()) @@ -337,13 +387,14 @@ func (o *Options) nextPageToken(listable Listable) (*token, error) { } return &token{ - SortByFieldName: o.SortByFieldName, - SortByFieldValue: sortByField.Interface(), - KeyFieldName: listable.PrimaryKeyColumnName(), - KeyFieldValue: keyField.Interface(), - IsDesc: o.IsDesc, - Filter: o.Filter, - ModelName: o.ModelName, + SortByFieldName: o.SortByFieldName, + SortByFieldValue: sortByField, + SortByFieldIsRunMetric: o.SortByFieldIsRunMetric, + KeyFieldName: listable.PrimaryKeyColumnName(), + KeyFieldValue: keyField.Interface(), + IsDesc: o.IsDesc, + Filter: o.Filter, + ModelName: o.ModelName, }, nil } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 332acf5e37e..215961f3c8c 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -43,6 +43,10 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetFieldValue(name string) interface{} { + return nil +} + func TestNextPageToken_ValidTokens(t *testing.T) { l := &fakeListable{PrimaryKey: "uuid123", FakeName: "Fake", CreatedTimestamp: 1234} diff --git a/backend/src/apiserver/model/experiment.go b/backend/src/apiserver/model/experiment.go index 40bdc27c6b8..255ceca3aa6 100644 --- a/backend/src/apiserver/model/experiment.go +++ b/backend/src/apiserver/model/experiment.go @@ -46,3 +46,8 @@ func (e *Experiment) APIToModelFieldMap() map[string]string { func (e *Experiment) GetModelName() string { return "experiments" } + +func (e *Experiment) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/job.go b/backend/src/apiserver/model/job.go index 2ea467e83ef..b9ecf9a1433 100644 --- a/backend/src/apiserver/model/job.go +++ b/backend/src/apiserver/model/job.go @@ -104,3 +104,8 @@ func (k *Job) APIToModelFieldMap() map[string]string { func (j *Job) GetModelName() string { return "jobs" } + +func (j *Job) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/pipeline.go b/backend/src/apiserver/model/pipeline.go index 08100e49ce1..5cd120df33c 100644 --- a/backend/src/apiserver/model/pipeline.go +++ b/backend/src/apiserver/model/pipeline.go @@ -78,3 +78,8 @@ func (p *Pipeline) APIToModelFieldMap() map[string]string { func (p *Pipeline) GetModelName() string { return "pipelines" } + +func (p *Pipeline) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/pipeline_version.go b/backend/src/apiserver/model/pipeline_version.go index 1cdb55a9197..7ba2983844a 100644 --- a/backend/src/apiserver/model/pipeline_version.go +++ b/backend/src/apiserver/model/pipeline_version.go @@ -73,3 +73,8 @@ func (p *PipelineVersion) APIToModelFieldMap() map[string]string { func (p *PipelineVersion) GetModelName() string { return "pipeline_versions" } + +func (p *PipelineVersion) GetFieldValue(name string) interface{} { + // TODO(jingzhang36): follow the example of GetFieldValue in run.go + return nil +} diff --git a/backend/src/apiserver/model/run.go b/backend/src/apiserver/model/run.go index 079e69602ce..288085f5e16 100644 --- a/backend/src/apiserver/model/run.go +++ b/backend/src/apiserver/model/run.go @@ -91,3 +91,32 @@ func (r *Run) GetModelName() string { // and thus as prefix in sorting fields. return "" } + +func (r *Run) GetFieldValue(name string) interface{} { + // "name" could be a field in Run type or a name inside an array typed field + // in Run type + // First, try to find the value if "name" is a field in Run type + switch name { + case "UUID": + return r.UUID + case "DisplayName": + return r.DisplayName + case "CreatedAtInSec": + return r.CreatedAtInSec + case "Description": + return r.Description + case "ScheduledAtInSec": + return r.ScheduledAtInSec + case "StorageState": + return r.StorageState + case "Conditions": + return r.Conditions + } + // Second, try to find the match of "name" inside an array typed field + for _, metric := range r.Metrics { + if metric.Name == name { + return metric.NumberValue + } + } + return nil +} diff --git a/backend/src/apiserver/server/list_request_util_test.go b/backend/src/apiserver/server/list_request_util_test.go index 76607317cf7..17991b039f3 100644 --- a/backend/src/apiserver/server/list_request_util_test.go +++ b/backend/src/apiserver/server/list_request_util_test.go @@ -245,6 +245,10 @@ func (f *fakeListable) GetModelName() string { return "" } +func (f *fakeListable) GetFieldValue(name string) interface{} { + return nil +} + func TestValidatedListOptions_Errors(t *testing.T) { opts, err := list.NewOptions(&fakeListable{}, 10, "name asc", nil) if err != nil { diff --git a/backend/src/apiserver/storage/run_store.go b/backend/src/apiserver/storage/run_store.go index 3cadb51f7ec..f283390fd7b 100644 --- a/backend/src/apiserver/storage/run_store.go +++ b/backend/src/apiserver/storage/run_store.go @@ -170,8 +170,9 @@ func (s *RunStore) buildSelectRunsQuery(selectCount bool, opts *list.Options, // If we're not just counting, then also add select columns and perform a left join // to get resource reference information. Also add pagination. if !selectCount { + sqlBuilder = opts.AddSortByRunMetricToSelect(sqlBuilder) sqlBuilder = opts.AddPaginationToSelect(sqlBuilder) - sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder) + sqlBuilder = s.addMetricsAndResourceReferences(sqlBuilder, opts) sqlBuilder = opts.AddSortingToSelect(sqlBuilder) } sql, args, err := sqlBuilder.ToSql() @@ -187,7 +188,7 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) { sq.Select(runColumns...). From("run_details"). Where(sq.Eq{"UUID": runId}). - Limit(1)). + Limit(1), nil). ToSql() if err != nil { @@ -213,17 +214,37 @@ func (s *RunStore) GetRun(runId string) (*model.RunDetail, error) { return runs[0], nil } -func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder) sq.SelectBuilder { +// Apply func f to every string in a given string slice. +func Map(vs []string, f func(string) string) []string { + vsm := make([]string, len(vs)) + for i, v := range vs { + vsm[i] = f(v) + } + return vsm +} + +func (s *RunStore) addMetricsAndResourceReferences(filteredSelectBuilder sq.SelectBuilder, opts *list.Options) sq.SelectBuilder { resourceRefConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rr.Payload", ","), `"]"`}, "") + columnsAfterJoiningResourceReferences := append( + Map(runColumns, func(column string) string { return "rd." + column }), // Add prefix "rd." to runColumns + resourceRefConcatQuery+" AS refs") + if opts != nil && opts.SortByFieldIsRunMetric { + columnsAfterJoiningResourceReferences = append(columnsAfterJoiningResourceReferences, "rd."+opts.SortByFieldName) + } subQ := sq. - Select("rd.*", resourceRefConcatQuery+" AS refs"). + Select(columnsAfterJoiningResourceReferences...). FromSelect(filteredSelectBuilder, "rd"). LeftJoin("resource_references AS rr ON rr.ResourceType='Run' AND rd.UUID=rr.ResourceUUID"). GroupBy("rd.UUID") + // TODO(jingzhang36): address the case where some runs don't have the metric used in order by. metricConcatQuery := s.db.Concat([]string{`"["`, s.db.GroupConcat("rm.Payload", ","), `"]"`}, "") + columnsAfterJoiningRunMetrics := append( + Map(runColumns, func(column string) string { return "subq." + column }), // Add prefix "subq." to runColumns + "subq.refs", + metricConcatQuery+" AS metrics") return sq. - Select("subq.*", metricConcatQuery+" AS metrics"). + Select(columnsAfterJoiningRunMetrics...). FromSelect(subQ, "subq"). LeftJoin("run_metrics AS rm ON subq.UUID=rm.RunUUID"). GroupBy("subq.UUID") diff --git a/backend/src/apiserver/storage/run_store_test.go b/backend/src/apiserver/storage/run_store_test.go index e9e49b578ea..d7375713cae 100644 --- a/backend/src/apiserver/storage/run_store_test.go +++ b/backend/src/apiserver/storage/run_store_test.go @@ -15,6 +15,8 @@ package storage import ( + "fmt" + "sort" "testing" sq "github.com/Masterminds/squirrel" @@ -27,6 +29,12 @@ import ( "google.golang.org/grpc/codes" ) +type RunMetricSorter []*model.RunMetric + +func (r RunMetricSorter) Len() int { return len(r) } +func (r RunMetricSorter) Less(i, j int) bool { return r[i].Name < r[j].Name } +func (r RunMetricSorter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + func initializeRunStore() (*DB, *RunStore) { db := NewFakeDbOrFatal() expStore := NewExperimentStore(db, util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(defaultFakeExpId, nil)) @@ -104,6 +112,24 @@ func initializeRunStore() (*DB, *RunStore) { runStore.CreateRun(run1) runStore.CreateRun(run2) runStore.CreateRun(run3) + + metric1 := &model.RunMetric{ + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + } + metric2 := &model.RunMetric{ + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + } + runStore.ReportMetric(metric1) + runStore.ReportMetric(metric2) + return db, runStore } @@ -121,6 +147,15 @@ func TestListRuns_Pagination(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -139,6 +174,15 @@ func TestListRuns_Pagination(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -168,6 +212,106 @@ func TestListRuns_Pagination(t *testing.T) { assert.Empty(t, nextPageToken) } +func TestListRuns_Pagination_WithSortingOnMetrics(t *testing.T) { + db, runStore := initializeRunStore() + defer db.Close() + + expectedFirstPageRuns := []*model.Run{ + { + UUID: "1", + Name: "run1", + DisplayName: "run1", + Namespace: "n1", + CreatedAtInSec: 1, + ScheduledAtInSec: 1, + StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), + Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, + ResourceReferences: []*model.ResourceReference{ + { + ResourceUUID: "1", ResourceType: common.Run, + ReferenceUUID: defaultFakeExpId, ReferenceName: "e1", + ReferenceType: common.Experiment, Relationship: common.Creator, + }, + }, + }} + expectedSecondPageRuns := []*model.Run{ + { + UUID: "2", + Name: "run2", + DisplayName: "run2", + Namespace: "n2", + CreatedAtInSec: 2, + ScheduledAtInSec: 2, + StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), + Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, + ResourceReferences: []*model.ResourceReference{ + { + ResourceUUID: "2", ResourceType: common.Run, + ReferenceUUID: defaultFakeExpId, ReferenceName: "e1", + ReferenceType: common.Experiment, Relationship: common.Creator, + }, + }, + }} + + // Sort in asc order + opts, err := list.NewOptions(&model.Run{}, 1, "metric:dummymetric", nil) + assert.Nil(t, err) + + runs, total_size, nextPageToken, err := runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") + assert.NotEmpty(t, nextPageToken) + + opts, err = list.NewOptionsFromToken(nextPageToken, 1) + assert.Nil(t, err) + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedSecondPageRuns, runs, "Unexpected Run listed.") + assert.Empty(t, nextPageToken) + + // Sort in desc order + opts, err = list.NewOptions(&model.Run{}, 1, "metric:dummymetric desc", nil) + assert.Nil(t, err) + + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedSecondPageRuns, runs, "Unexpected Run listed.") + assert.NotEmpty(t, nextPageToken) + + opts, err = list.NewOptionsFromToken(nextPageToken, 1) + assert.Nil(t, err) + runs, total_size, nextPageToken, err = runStore.ListRuns( + &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + assert.Nil(t, err) + assert.Equal(t, 2, total_size) + assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") + assert.Empty(t, nextPageToken) +} + func TestListRuns_TotalSizeWithNoFilter(t *testing.T) { db, runStore := initializeRunStore() defer db.Close() @@ -219,6 +363,15 @@ func TestListRuns_Pagination_Descend(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -237,6 +390,15 @@ func TestListRuns_Pagination_Descend(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -251,6 +413,10 @@ func TestListRuns_Pagination_Descend(t *testing.T) { runs, total_size, nextPageToken, err := runStore.ListRuns( &common.FilterContext{ReferenceKey: &common.ReferenceKey{Type: common.Experiment, ID: defaultFakeExpId}}, opts) + for _, run := range runs { + fmt.Printf("%+v\n", run) + } + assert.Nil(t, err) assert.Equal(t, 2, total_size) assert.Equal(t, expectedFirstPageRuns, runs, "Unexpected Run listed.") @@ -280,6 +446,15 @@ func TestListRuns_Pagination_LessThanPageSize(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -297,6 +472,15 @@ func TestListRuns_Pagination_LessThanPageSize(t *testing.T) { ScheduledAtInSec: 2, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "2", ResourceType: common.Run, @@ -341,6 +525,15 @@ func TestGetRun(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -389,6 +582,15 @@ func TestCreateOrUpdateRun_UpdateSuccess(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -426,6 +628,15 @@ func TestCreateOrUpdateRun_UpdateSuccess(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "done", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -559,6 +770,15 @@ func TestCreateOrUpdateRun_BadStorageStateValue(t *testing.T) { CreatedAtInSec: 1, ScheduledAtInSec: 1, Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -603,6 +823,15 @@ func TestTerminateRun(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_AVAILABLE.String(), Conditions: "Terminating", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run, @@ -652,7 +881,16 @@ func TestReportMetric_Success(t *testing.T) { runDetail, err := runStore.GetRun("1") assert.Nil(t, err, "Got error: %+v", err) - assert.Equal(t, []*model.RunMetric{metric}, runDetail.Run.Metrics) + sort.Sort(RunMetricSorter(runDetail.Run.Metrics)) + assert.Equal(t, []*model.RunMetric{ + metric, + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }}, runDetail.Run.Metrics) } func TestReportMetric_DupReports_Fail(t *testing.T) { @@ -744,7 +982,16 @@ func TestListRuns_WithMetrics(t *testing.T) { ReferenceType: common.Experiment, Relationship: common.Creator, }, }, - Metrics: []*model.RunMetric{metric1, metric2}, + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + metric1, + metric2}, }, { UUID: "2", @@ -762,15 +1009,29 @@ func TestListRuns_WithMetrics(t *testing.T) { ReferenceType: common.Experiment, Relationship: common.Creator, }, }, - Metrics: []*model.RunMetric{metric3}, + Metrics: []*model.RunMetric{ + { + RunUUID: "2", + NodeID: "node2", + Name: "dummymetric", + NumberValue: 2.0, + Format: "PERCENTAGE", + }, + metric3}, }, } - opts, err := list.NewOptions(&model.Run{}, 2, "", nil) + opts, err := list.NewOptions(&model.Run{}, 2, "id", nil) assert.Nil(t, err) runs, total_size, _, err := runStore.ListRuns(&common.FilterContext{}, opts) assert.Equal(t, 3, total_size) assert.Nil(t, err) + for _, run := range expectedRuns { + sort.Sort(RunMetricSorter(run.Metrics)) + } + for _, run := range runs { + sort.Sort(RunMetricSorter(run.Metrics)) + } assert.Equal(t, expectedRuns, runs, "Unexpected Run listed.") } @@ -866,6 +1127,15 @@ func TestArchiveRun_IncludedInRunList(t *testing.T) { ScheduledAtInSec: 1, StorageState: api.Run_STORAGESTATE_ARCHIVED.String(), Conditions: "Running", + Metrics: []*model.RunMetric{ + { + RunUUID: "1", + NodeID: "node1", + Name: "dummymetric", + NumberValue: 1.0, + Format: "PERCENTAGE", + }, + }, ResourceReferences: []*model.ResourceReference{ { ResourceUUID: "1", ResourceType: common.Run,