Skip to content

Commit

Permalink
Implement json visitor for plan node (milvus-io#22897)
Browse files Browse the repository at this point in the history
Signed-off-by: longjiquan <jiquan.long@zilliz.com>
  • Loading branch information
longjiquan authored Mar 21, 2023
1 parent 348fba4 commit ce2c5d1
Show file tree
Hide file tree
Showing 15 changed files with 800 additions and 22 deletions.
3 changes: 2 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ run:
skip-files:
# The worst lint rule of golang.
- internal/mysqld/parser/antlrparser/ast_builder.go
- internal/mysqld/parser/antlrparser/node_ret.go
- internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go
- internal/mysqld/planner/sql_statement.go
- internal/mysqld/planner/sql_statements.go
- internal/mysqld/planner/visitor_json_wrong_var_naming.go

linters:
disable-all: true
Expand Down
222 changes: 220 additions & 2 deletions internal/mysqld/executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func Test_defaultExecutor_Run(t *testing.T) {
planner.NewNodeSqlStatement("sql2"),
}
plan := &planner.PhysicalPlan{
Node: planner.NewNodeSqlStatements(stmts, "sql1; sql2"),
Node: planner.NewNodeSqlStatements("sql1; sql2", stmts),
}
_, err := e.Run(context.TODO(), plan)
assert.Error(t, err)
Expand Down Expand Up @@ -108,7 +108,7 @@ func Test_defaultExecutor_Run(t *testing.T) {
)),
}
plan := &planner.PhysicalPlan{
Node: planner.NewNodeSqlStatements(stmts, ""),
Node: planner.NewNodeSqlStatements("", stmts),
}
sqlRes, err := e.Run(context.TODO(), plan)
assert.NoError(t, err)
Expand Down Expand Up @@ -501,3 +501,221 @@ func Test_getOutputFieldsOrMatchCountRule(t *testing.T) {
assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields)
})
}

func Test_defaultExecutor_execCountWithFilter(t *testing.T) {
t.Run("failed to query", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(nil, errors.New("error mock Query"))
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execCountWithFilter(context.TODO(), "t", "a > 2")
assert.Error(t, err)
})

t.Run("normal case", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
res := &milvuspb.QueryResults{
Status: &commonpb.Status{},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "field",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4},
},
},
},
},
},
},
CollectionName: "test",
}
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // *milvuspb.QueryRequest
).Return(res, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
sqlRes, err := e.execCountWithFilter(context.TODO(), "t", "a > 2")
assert.NoError(t, err)
assert.Equal(t, 1, len(sqlRes.Fields))
assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
assert.Equal(t, 1, len(sqlRes.Rows))
assert.Equal(t, 1, len(sqlRes.Rows[0]))
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
})
}

func Test_defaultExecutor_execQuery(t *testing.T) {
t.Run("rpc failure", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(nil, errors.New("error mock Query"))
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.Error(t, err)
})

t.Run("not success", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(&milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "error mock reason",
},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.Error(t, err)
})

t.Run("success", func(t *testing.T) {
s := mocks.NewProxyComponent(t)
s.On("Query",
mock.Anything, // context.Context
mock.Anything, // milvuspb.QueryRequest
).Return(&milvuspb.QueryResults{
Status: &commonpb.Status{},
}, nil)
e := NewDefaultExecutor(s).(*defaultExecutor)
_, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"})
assert.NoError(t, err)
})
}

func Test_wrapCountResult(t *testing.T) {
sqlRes := wrapCountResult(100, "count(*)")
assert.Equal(t, 1, len(sqlRes.Fields))
assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
assert.Equal(t, 1, len(sqlRes.Rows))
assert.Equal(t, 1, len(sqlRes.Rows[0]))
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
}

func Test_wrapQueryResults(t *testing.T) {
res := &milvuspb.QueryResults{
Status: &commonpb.Status{},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "field",
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4},
},
},
},
},
},
},
CollectionName: "test",
}
sqlRes := wrapQueryResults(res)
assert.Equal(t, 1, len(sqlRes.Fields))
assert.Equal(t, 4, len(sqlRes.Rows))
assert.Equal(t, "field", sqlRes.Fields[0].Name)
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
assert.Equal(t, 1, len(sqlRes.Rows[0]))
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
}

func Test_getSQLField(t *testing.T) {
f := &schemapb.FieldData{
FieldName: "a",
Type: schemapb.DataType_Int64,
}
sf := getSQLField("t", f)
assert.Equal(t, "a", sf.Name)
assert.Equal(t, querypb.Type_INT64, sf.Type)
assert.Equal(t, "t", sf.Table)
}

func Test_toSQLType(t *testing.T) {
type args struct {
t schemapb.DataType
}
tests := []struct {
name string
args args
want querypb.Type
}{
{
args: args{
t: schemapb.DataType_Bool,
},
want: querypb.Type_UINT8,
},
{
args: args{
t: schemapb.DataType_Int8,
},
want: querypb.Type_INT8,
},
{
args: args{
t: schemapb.DataType_Int16,
},
want: querypb.Type_INT16,
},
{
args: args{
t: schemapb.DataType_Int32,
},
want: querypb.Type_INT32,
},
{
args: args{
t: schemapb.DataType_Int64,
},
want: querypb.Type_INT64,
},
{
args: args{
t: schemapb.DataType_Float,
},
want: querypb.Type_FLOAT32,
},
{
args: args{
t: schemapb.DataType_Double,
},
want: querypb.Type_FLOAT64,
},
{
args: args{
t: schemapb.DataType_VarChar,
},
want: querypb.Type_VARCHAR,
},
{
args: args{
t: schemapb.DataType_FloatVector,
},
want: querypb.Type_NULL_TYPE,
},
{
args: args{
t: schemapb.DataType_BinaryVector,
},
want: querypb.Type_NULL_TYPE,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t)
})
}
}
15 changes: 15 additions & 0 deletions internal/mysqld/optimizer/cbo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package optimizer

import (
"testing"

"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/stretchr/testify/assert"
)

func Test_defaultCBO_Optimize(t *testing.T) {
cbo := NewDefaultCBO()
plan := &planner.PhysicalPlan{}
optimized := cbo.Optimize(plan)
assert.Same(t, plan, optimized)
}
15 changes: 15 additions & 0 deletions internal/mysqld/optimizer/rbo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package optimizer

import (
"testing"

"github.com/milvus-io/milvus/internal/mysqld/planner"
"github.com/stretchr/testify/assert"
)

func Test_defaultRBO_Optimize(t *testing.T) {
rbo := NewDefaultRBO()
plan := &planner.LogicalPlan{}
optimized := rbo.Optimize(plan)
assert.Same(t, plan, optimized)
}
2 changes: 1 addition & 1 deletion internal/mysqld/parser/antlrparser/ast_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (v *AstBuilder) VisitSqlStatements(ctx *parsergen.SqlStatementsContext) int
}
sqlStatements = append(sqlStatements, n)
}
return planner.NewNodeSqlStatements(sqlStatements, GetOriginalText(ctx))
return planner.NewNodeSqlStatements(GetOriginalText(ctx), sqlStatements)
}

func (v *AstBuilder) VisitSqlStatement(ctx *parsergen.SqlStatementContext) interface{} {
Expand Down
16 changes: 0 additions & 16 deletions internal/mysqld/parser/antlrparser/node_ret.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@ func GetNode(obj interface{}) planner.Node {
return n
}

func GetSqlStatements(obj interface{}) *planner.NodeSqlStatements {
n, ok := obj.(*planner.NodeSqlStatements)
if !ok {
return nil
}
return n
}

func GetSqlStatement(obj interface{}) *planner.NodeSqlStatement {
n, ok := obj.(*planner.NodeSqlStatement)
if !ok {
return nil
}
return n
}

func GetDmlStatement(obj interface{}) *planner.NodeDmlStatement {
n, ok := obj.(*planner.NodeDmlStatement)
if !ok {
Expand Down
19 changes: 19 additions & 0 deletions internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package antlrparser

import "github.com/milvus-io/milvus/internal/mysqld/planner"

func GetSqlStatements(obj interface{}) *planner.NodeSqlStatements {
n, ok := obj.(*planner.NodeSqlStatements)
if !ok {
return nil
}
return n
}

func GetSqlStatement(obj interface{}) *planner.NodeSqlStatement {
n, ok := obj.(*planner.NodeSqlStatement)
if !ok {
return nil
}
return n
}
14 changes: 14 additions & 0 deletions internal/mysqld/parser/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package parser

import "testing"

func TestConfig_Apply(t *testing.T) {
opts := []Option{
func(c *Config) {
},
func(c *Config) {
},
}
c := defaultParserConfig()
c.Apply(opts...)
}
9 changes: 9 additions & 0 deletions internal/mysqld/planner/node_equal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package planner

import "reflect"

func Equal(n1, n2 Node) bool {
v := NewJSONVisitor()
j1, j2 := n1.Accept(v), n2.Accept(v)
return reflect.DeepEqual(j1, j2)
}
Loading

0 comments on commit ce2c5d1

Please sign in to comment.