Skip to content

Commit

Permalink
Implement filtering in ListPipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
jiezhang committed Feb 11, 2020
1 parent b96ce12 commit 3606718
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 8 deletions.
18 changes: 12 additions & 6 deletions backend/src/apiserver/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,24 @@ func New(filterProto *api.Filter) (*Filter, error) {
return f, nil
}

// NewWithKeyMap is like New, but takes an additional map for mapping key names
// NewWithKeyMap is like New, but takes an additional map and model name for mapping key names
// in the protocol buffer to an appropriate name for use when querying the
// model. For example, if the API name of a field is "foo" and the equivalent
// model name is "ModelFoo", then filterProto with predicates against key "foo"
// will be parsed as if the key value was "ModelFoo".
func NewWithKeyMap(filterProto *api.Filter, keyMap map[string]string) (*Filter, error) {
// model. For example, if the API name of a field is "name", the model name is "pipelines", and
// the equivalent column name is "Name", then filterProto with predicates against key "name"
// will be parsed as if the key value was "pipelines.Name".
func NewWithKeyMap(filterProto *api.Filter, keyMap map[string]string, modelName string) (*Filter, error) {
// Fully qualify column name to avoid "ambiguous column name" error.
var modelNamePrefix string
if modelName != "" {
modelNamePrefix = modelName + "."
}

for _, pred := range filterProto.Predicates {
k, ok := keyMap[pred.Key]
if !ok {
return nil, util.NewInvalidInputError("no support for filtering on unrecognized field %q", pred.Key)
}
pred.Key = k
pred.Key = modelNamePrefix + k
}
return New(filterProto)
}
Expand Down
50 changes: 50 additions & 0 deletions backend/src/apiserver/filter/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
api "github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/kubeflow/pipelines/backend/src/apiserver/model"
)

func TestValidNewFilters(t *testing.T) {
Expand Down Expand Up @@ -87,6 +88,55 @@ func TestValidNewFilters(t *testing.T) {
}
}

func TestValidNewFiltersWithKeyMap(t *testing.T) {
opts := []cmp.Option{
cmp.AllowUnexported(Filter{}),
cmp.FilterPath(func(p cmp.Path) bool {
return p.String() == "filterProto"
}, cmp.Ignore()),
cmpopts.EquateEmpty(),
}

tests := []struct {
protoStr string
want *Filter
}{
{
`predicates { key: "name" op: EQUALS string_value: "pipeline" }`,
&Filter{eq: map[string]interface{}{"pipelines.Name": "pipeline"}},
},
{
`predicates { key: "name" op: NOT_EQUALS string_value: "pipeline" }`,
&Filter{neq: map[string]interface{}{"pipelines.Name": "pipeline"}},
},
{
`predicates {
key: "name" op: IN
string_values { values: 'pipeline_1' values: 'pipeline_2' } }`,
&Filter{in: map[string]interface{}{"pipelines.Name": []string{"pipeline_1", "pipeline_2"}}},
},
{
`predicates {
key: "name" op: IS_SUBSTRING string_value: "pipeline" }`,
&Filter{substring: map[string]interface{}{"pipelines.Name": "pipeline"}},
},
}

for _, test := range tests {
filterProto := &api.Filter{}
if err := proto.UnmarshalText(test.protoStr, filterProto); err != nil {
t.Errorf("Failed to unmarshal Filter text proto\n%q\nError: %v", test.protoStr, err)
continue
}

listable := &model.Pipeline{}
got, err := NewWithKeyMap(filterProto, listable.APIToModelFieldMap(), listable.GetModelName())
if !cmp.Equal(got, test.want, opts...) || err != nil {
t.Errorf("New(%+v) = %+v, %v\nWant %+v, nil", *filterProto, got, err, test.want)
}
}
}

func TestInvalidFilters(t *testing.T) {
tests := []struct {
protoStr string
Expand Down
2 changes: 1 addition & 1 deletion backend/src/apiserver/list/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func NewOptions(listable Listable, pageSize int, sortBy string, filterProto *api

// Filtering.
if filterProto != nil {
f, err := filter.NewWithKeyMap(filterProto, listable.APIToModelFieldMap())
f, err := filter.NewWithKeyMap(filterProto, listable.APIToModelFieldMap(), listable.GetModelName())
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion backend/src/apiserver/storage/pipeline_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (s *PipelineStore) ListPipelines(opts *list.Options) ([]*model.Pipeline, in
}

buildQuery := func(sqlBuilder sq.SelectBuilder) sq.SelectBuilder {
return sqlBuilder.
return opts.AddFilterToSelect(sqlBuilder).
From("pipelines").
LeftJoin("pipeline_versions ON pipelines.DefaultVersionId = pipeline_versions.UUID").
Where(sq.Eq{"pipelines.Status": model.PipelineReady})
Expand Down
47 changes: 47 additions & 0 deletions backend/src/apiserver/storage/pipeline_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package storage
import (
"testing"

api "github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/kubeflow/pipelines/backend/src/apiserver/list"
"github.com/kubeflow/pipelines/backend/src/apiserver/model"
"github.com/kubeflow/pipelines/backend/src/common/util"
Expand Down Expand Up @@ -102,6 +103,52 @@ func TestListPipelines_FilterOutNotReady(t *testing.T) {
assert.Equal(t, pipelinesExpected, pipelines)
}

func TestListPipelines_WithFilter(t *testing.T) {
db := NewFakeDbOrFatal()
defer db.Close()
pipelineStore := NewPipelineStore(db, util.NewFakeTimeForEpoch(), util.NewFakeUUIDGeneratorOrFatal(fakeUUID, nil))
pipelineStore.CreatePipeline(createPipeline("pipeline_foo"))
pipelineStore.uuid = util.NewFakeUUIDGeneratorOrFatal(fakeUUIDTwo, nil)
pipelineStore.CreatePipeline(createPipeline("pipeline_bar"))
pipelineStore.uuid = util.NewFakeUUIDGeneratorOrFatal(fakeUUIDThree, nil)

expectedPipeline1 := &model.Pipeline{
UUID: fakeUUID,
CreatedAtInSec: 1,
Name: "pipeline_foo",
Parameters: `[{"Name": "param1"}]`,
Status: model.PipelineReady,
DefaultVersionId: fakeUUID,
DefaultVersion: &model.PipelineVersion{
UUID: fakeUUID,
CreatedAtInSec: 1,
Name: "pipeline_foo",
Parameters: `[{"Name": "param1"}]`,
PipelineId: fakeUUID,
Status: model.PipelineVersionReady,
}}
pipelinesExpected := []*model.Pipeline{expectedPipeline1}

filterProto := &api.Filter{
Predicates: []*api.Predicate{
&api.Predicate{
Key: "name",
Op: api.Predicate_IS_SUBSTRING,
Value: &api.Predicate_StringValue{StringValue: "pipeline_f"},
},
},
}
opts, err := list.NewOptions(&model.Pipeline{}, 10, "id", filterProto)
assert.Nil(t, err)

pipelines, totalSize, nextPageToken, err := pipelineStore.ListPipelines(opts)

assert.Nil(t, err)
assert.Equal(t, "", nextPageToken)
assert.Equal(t, 1, totalSize)
assert.Equal(t, pipelinesExpected, pipelines)
}

func TestListPipelines_Pagination(t *testing.T) {
db := NewFakeDbOrFatal()
defer db.Close()
Expand Down

0 comments on commit 3606718

Please sign in to comment.