From ce2c5d15d903d5b8aa4c28fa7490b4c9ed0f87c0 Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Tue, 21 Mar 2023 20:13:57 +0800 Subject: [PATCH] Implement json visitor for plan node (#22897) Signed-off-by: longjiquan --- .golangci.yml | 3 +- internal/mysqld/executor/executor_test.go | 222 ++++++++++- internal/mysqld/optimizer/cbo_test.go | 15 + internal/mysqld/optimizer/rbo_test.go | 15 + .../mysqld/parser/antlrparser/ast_builder.go | 2 +- .../mysqld/parser/antlrparser/node_ret.go | 16 - .../antlrparser/node_ret_wrong_var_naming.go | 19 + internal/mysqld/parser/config_test.go | 14 + internal/mysqld/planner/node_equal.go | 9 + internal/mysqld/planner/node_equal_test.go | 96 +++++ internal/mysqld/planner/sql_statements.go | 2 +- internal/mysqld/planner/test_value.go | 13 +- ...go => visitor_expression_text_restorer.go} | 0 internal/mysqld/planner/visitor_json.go | 373 ++++++++++++++++++ .../planner/visitor_json_wrong_var_naming.go | 23 ++ 15 files changed, 800 insertions(+), 22 deletions(-) create mode 100644 internal/mysqld/optimizer/cbo_test.go create mode 100644 internal/mysqld/optimizer/rbo_test.go create mode 100644 internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go create mode 100644 internal/mysqld/parser/config_test.go create mode 100644 internal/mysqld/planner/node_equal.go create mode 100644 internal/mysqld/planner/node_equal_test.go rename internal/mysqld/planner/{expression_text_restorer.go => visitor_expression_text_restorer.go} (100%) create mode 100644 internal/mysqld/planner/visitor_json.go create mode 100644 internal/mysqld/planner/visitor_json_wrong_var_naming.go diff --git a/.golangci.yml b/.golangci.yml index ed0cded330ed7..1a284dc391566 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -12,9 +12,10 @@ run: skip-files: # The worst lint rule of golang. - internal/mysqld/parser/antlrparser/ast_builder.go - - internal/mysqld/parser/antlrparser/node_ret.go + - internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go - internal/mysqld/planner/sql_statement.go - internal/mysqld/planner/sql_statements.go + - internal/mysqld/planner/visitor_json_wrong_var_naming.go linters: disable-all: true diff --git a/internal/mysqld/executor/executor_test.go b/internal/mysqld/executor/executor_test.go index 27e0f65dbf27e..480adc2fd4d2a 100644 --- a/internal/mysqld/executor/executor_test.go +++ b/internal/mysqld/executor/executor_test.go @@ -37,7 +37,7 @@ func Test_defaultExecutor_Run(t *testing.T) { planner.NewNodeSqlStatement("sql2"), } plan := &planner.PhysicalPlan{ - Node: planner.NewNodeSqlStatements(stmts, "sql1; sql2"), + Node: planner.NewNodeSqlStatements("sql1; sql2", stmts), } _, err := e.Run(context.TODO(), plan) assert.Error(t, err) @@ -108,7 +108,7 @@ func Test_defaultExecutor_Run(t *testing.T) { )), } plan := &planner.PhysicalPlan{ - Node: planner.NewNodeSqlStatements(stmts, ""), + Node: planner.NewNodeSqlStatements("", stmts), } sqlRes, err := e.Run(context.TODO(), plan) assert.NoError(t, err) @@ -501,3 +501,221 @@ func Test_getOutputFieldsOrMatchCountRule(t *testing.T) { assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields) }) } + +func Test_defaultExecutor_execCountWithFilter(t *testing.T) { + t.Run("failed to query", func(t *testing.T) { + s := mocks.NewProxyComponent(t) + s.On("Query", + mock.Anything, // context.Context + mock.Anything, // milvuspb.QueryRequest + ).Return(nil, errors.New("error mock Query")) + e := NewDefaultExecutor(s).(*defaultExecutor) + _, err := e.execCountWithFilter(context.TODO(), "t", "a > 2") + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + s := mocks.NewProxyComponent(t) + res := &milvuspb.QueryResults{ + Status: &commonpb.Status{}, + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "field", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + }, + CollectionName: "test", + } + s.On("Query", + mock.Anything, // context.Context + mock.Anything, // *milvuspb.QueryRequest + ).Return(res, nil) + e := NewDefaultExecutor(s).(*defaultExecutor) + sqlRes, err := e.execCountWithFilter(context.TODO(), "t", "a > 2") + assert.NoError(t, err) + assert.Equal(t, 1, len(sqlRes.Fields)) + assert.Equal(t, "count(*)", sqlRes.Fields[0].Name) + assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type) + assert.Equal(t, 1, len(sqlRes.Rows)) + assert.Equal(t, 1, len(sqlRes.Rows[0])) + assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type()) + }) +} + +func Test_defaultExecutor_execQuery(t *testing.T) { + t.Run("rpc failure", func(t *testing.T) { + s := mocks.NewProxyComponent(t) + s.On("Query", + mock.Anything, // context.Context + mock.Anything, // milvuspb.QueryRequest + ).Return(nil, errors.New("error mock Query")) + e := NewDefaultExecutor(s).(*defaultExecutor) + _, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"}) + assert.Error(t, err) + }) + + t.Run("not success", func(t *testing.T) { + s := mocks.NewProxyComponent(t) + s.On("Query", + mock.Anything, // context.Context + mock.Anything, // milvuspb.QueryRequest + ).Return(&milvuspb.QueryResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "error mock reason", + }, + }, nil) + e := NewDefaultExecutor(s).(*defaultExecutor) + _, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"}) + assert.Error(t, err) + }) + + t.Run("success", func(t *testing.T) { + s := mocks.NewProxyComponent(t) + s.On("Query", + mock.Anything, // context.Context + mock.Anything, // milvuspb.QueryRequest + ).Return(&milvuspb.QueryResults{ + Status: &commonpb.Status{}, + }, nil) + e := NewDefaultExecutor(s).(*defaultExecutor) + _, err := e.execQuery(context.TODO(), "t", "a > 2", []string{"a"}) + assert.NoError(t, err) + }) +} + +func Test_wrapCountResult(t *testing.T) { + sqlRes := wrapCountResult(100, "count(*)") + assert.Equal(t, 1, len(sqlRes.Fields)) + assert.Equal(t, "count(*)", sqlRes.Fields[0].Name) + assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type) + assert.Equal(t, 1, len(sqlRes.Rows)) + assert.Equal(t, 1, len(sqlRes.Rows[0])) + assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type()) +} + +func Test_wrapQueryResults(t *testing.T) { + res := &milvuspb.QueryResults{ + Status: &commonpb.Status{}, + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "field", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + }, + CollectionName: "test", + } + sqlRes := wrapQueryResults(res) + assert.Equal(t, 1, len(sqlRes.Fields)) + assert.Equal(t, 4, len(sqlRes.Rows)) + assert.Equal(t, "field", sqlRes.Fields[0].Name) + assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type) + assert.Equal(t, 1, len(sqlRes.Rows[0])) + assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type()) +} + +func Test_getSQLField(t *testing.T) { + f := &schemapb.FieldData{ + FieldName: "a", + Type: schemapb.DataType_Int64, + } + sf := getSQLField("t", f) + assert.Equal(t, "a", sf.Name) + assert.Equal(t, querypb.Type_INT64, sf.Type) + assert.Equal(t, "t", sf.Table) +} + +func Test_toSQLType(t *testing.T) { + type args struct { + t schemapb.DataType + } + tests := []struct { + name string + args args + want querypb.Type + }{ + { + args: args{ + t: schemapb.DataType_Bool, + }, + want: querypb.Type_UINT8, + }, + { + args: args{ + t: schemapb.DataType_Int8, + }, + want: querypb.Type_INT8, + }, + { + args: args{ + t: schemapb.DataType_Int16, + }, + want: querypb.Type_INT16, + }, + { + args: args{ + t: schemapb.DataType_Int32, + }, + want: querypb.Type_INT32, + }, + { + args: args{ + t: schemapb.DataType_Int64, + }, + want: querypb.Type_INT64, + }, + { + args: args{ + t: schemapb.DataType_Float, + }, + want: querypb.Type_FLOAT32, + }, + { + args: args{ + t: schemapb.DataType_Double, + }, + want: querypb.Type_FLOAT64, + }, + { + args: args{ + t: schemapb.DataType_VarChar, + }, + want: querypb.Type_VARCHAR, + }, + { + args: args{ + t: schemapb.DataType_FloatVector, + }, + want: querypb.Type_NULL_TYPE, + }, + { + args: args{ + t: schemapb.DataType_BinaryVector, + }, + want: querypb.Type_NULL_TYPE, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t) + }) + } +} diff --git a/internal/mysqld/optimizer/cbo_test.go b/internal/mysqld/optimizer/cbo_test.go new file mode 100644 index 0000000000000..348f2ce3ed6fe --- /dev/null +++ b/internal/mysqld/optimizer/cbo_test.go @@ -0,0 +1,15 @@ +package optimizer + +import ( + "testing" + + "github.com/milvus-io/milvus/internal/mysqld/planner" + "github.com/stretchr/testify/assert" +) + +func Test_defaultCBO_Optimize(t *testing.T) { + cbo := NewDefaultCBO() + plan := &planner.PhysicalPlan{} + optimized := cbo.Optimize(plan) + assert.Same(t, plan, optimized) +} diff --git a/internal/mysqld/optimizer/rbo_test.go b/internal/mysqld/optimizer/rbo_test.go new file mode 100644 index 0000000000000..63d88c6a0c483 --- /dev/null +++ b/internal/mysqld/optimizer/rbo_test.go @@ -0,0 +1,15 @@ +package optimizer + +import ( + "testing" + + "github.com/milvus-io/milvus/internal/mysqld/planner" + "github.com/stretchr/testify/assert" +) + +func Test_defaultRBO_Optimize(t *testing.T) { + rbo := NewDefaultRBO() + plan := &planner.LogicalPlan{} + optimized := rbo.Optimize(plan) + assert.Same(t, plan, optimized) +} diff --git a/internal/mysqld/parser/antlrparser/ast_builder.go b/internal/mysqld/parser/antlrparser/ast_builder.go index 7974fc621bf7c..53aabe2572fc5 100644 --- a/internal/mysqld/parser/antlrparser/ast_builder.go +++ b/internal/mysqld/parser/antlrparser/ast_builder.go @@ -40,7 +40,7 @@ func (v *AstBuilder) VisitSqlStatements(ctx *parsergen.SqlStatementsContext) int } sqlStatements = append(sqlStatements, n) } - return planner.NewNodeSqlStatements(sqlStatements, GetOriginalText(ctx)) + return planner.NewNodeSqlStatements(GetOriginalText(ctx), sqlStatements) } func (v *AstBuilder) VisitSqlStatement(ctx *parsergen.SqlStatementContext) interface{} { diff --git a/internal/mysqld/parser/antlrparser/node_ret.go b/internal/mysqld/parser/antlrparser/node_ret.go index cf5dda4d8e327..d79650457ab83 100644 --- a/internal/mysqld/parser/antlrparser/node_ret.go +++ b/internal/mysqld/parser/antlrparser/node_ret.go @@ -21,22 +21,6 @@ func GetNode(obj interface{}) planner.Node { return n } -func GetSqlStatements(obj interface{}) *planner.NodeSqlStatements { - n, ok := obj.(*planner.NodeSqlStatements) - if !ok { - return nil - } - return n -} - -func GetSqlStatement(obj interface{}) *planner.NodeSqlStatement { - n, ok := obj.(*planner.NodeSqlStatement) - if !ok { - return nil - } - return n -} - func GetDmlStatement(obj interface{}) *planner.NodeDmlStatement { n, ok := obj.(*planner.NodeDmlStatement) if !ok { diff --git a/internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go b/internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go new file mode 100644 index 0000000000000..96b27f5705905 --- /dev/null +++ b/internal/mysqld/parser/antlrparser/node_ret_wrong_var_naming.go @@ -0,0 +1,19 @@ +package antlrparser + +import "github.com/milvus-io/milvus/internal/mysqld/planner" + +func GetSqlStatements(obj interface{}) *planner.NodeSqlStatements { + n, ok := obj.(*planner.NodeSqlStatements) + if !ok { + return nil + } + return n +} + +func GetSqlStatement(obj interface{}) *planner.NodeSqlStatement { + n, ok := obj.(*planner.NodeSqlStatement) + if !ok { + return nil + } + return n +} diff --git a/internal/mysqld/parser/config_test.go b/internal/mysqld/parser/config_test.go new file mode 100644 index 0000000000000..c426c38c8e925 --- /dev/null +++ b/internal/mysqld/parser/config_test.go @@ -0,0 +1,14 @@ +package parser + +import "testing" + +func TestConfig_Apply(t *testing.T) { + opts := []Option{ + func(c *Config) { + }, + func(c *Config) { + }, + } + c := defaultParserConfig() + c.Apply(opts...) +} diff --git a/internal/mysqld/planner/node_equal.go b/internal/mysqld/planner/node_equal.go new file mode 100644 index 0000000000000..a615603375709 --- /dev/null +++ b/internal/mysqld/planner/node_equal.go @@ -0,0 +1,9 @@ +package planner + +import "reflect" + +func Equal(n1, n2 Node) bool { + v := NewJSONVisitor() + j1, j2 := n1.Accept(v), n2.Accept(v) + return reflect.DeepEqual(j1, j2) +} diff --git a/internal/mysqld/planner/node_equal_test.go b/internal/mysqld/planner/node_equal_test.go new file mode 100644 index 0000000000000..819b0b9a80237 --- /dev/null +++ b/internal/mysqld/planner/node_equal_test.go @@ -0,0 +1,96 @@ +package planner + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEqual(t *testing.T) { + // a in ["100", false, 666.666] + expr1 := NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithInPredicate( + NewNodeInPredicate("", NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithFullColumnName( + NewNodeFullColumnName("", "a")))))), NewNodeExpressions("", []*NodeExpression{ + NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithConstant( + NewNodeConstant("", WithStringLiteral("100"))))))))), + NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithConstant( + NewNodeConstant("", WithBooleanLiteral(false))))))))), + NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithConstant( + NewNodeConstant("", WithRealLiteral(666.666))))))))), + }), InOperatorIn))))) + + // b >= (+100) + expr2 := NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithNodeBinaryComparisonPredicate( + NewNodeBinaryComparisonPredicate("", NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithFullColumnName( + NewNodeFullColumnName("", "b")))))), NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithNestedExpr( + NewNodeNestedExpressionAtom("", []*NodeExpression{ + NewNodeExpression("", WithPredicate( + NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithUnaryExpr( + NewNodeUnaryExpressionAtom("", NewNodeExpressionAtom("", ExpressionAtomWithConstant( + NewNodeConstant("", WithDecimalLiteral(100)))), UnaryOperatorPositive)))))))), + })))))), + ComparisonOperatorGreaterEqual))))) + + // c is true + expr3 := NewNodeExpression("", WithIsExpr( + NewNodeIsExpression("", NewNodePredicate("", WithNodeExpressionAtomPredicate( + NewNodeExpressionAtomPredicate("", + NewNodeExpressionAtom("", ExpressionAtomWithFullColumnName( + NewNodeFullColumnName("", "c")))))), TestValueTrue, IsOperatorIs))) + + // ((not expr1) & (expr2)) | expr3 + expr := NewNodeExpression("", WithLogicalExpr( + NewNodeLogicalExpression("", + NewNodeExpression("", WithLogicalExpr( + NewNodeLogicalExpression("", + NewNodeExpression("", WithNotExpr( + NewNodeNotExpression("", expr1))), + expr2, + LogicalOperatorAnd))), + expr3, + LogicalOperatorOr))) + + n := NewNodeSqlStatements("", []*NodeSqlStatement{ + NewNodeSqlStatement("", WithDmlStatement( + NewNodeDmlStatement("", WithSelectStatement( + NewNodeSelectStatement("", WithSimpleSelect( + NewNodeSimpleSelect("", WithLockClause( + NewNodeLockClause("", LockClauseOptionForUpdate)), WithQuery( + NewNodeQuerySpecification("", []*NodeSelectSpec{ + NewNodeSelectSpec(""), + }, []*NodeSelectElement{ + NewNodeSelectElement("", WithStar()), + NewNodeSelectElement("", WithFullColumnName( + NewNodeFullColumnName("", "a", FullColumnNameWithAlias("alias1")))), + NewNodeSelectElement("", WithFunctionCall( + NewNodeFunctionCall("", FunctionCallWithAlias("alias2"), WithAgg( + NewNodeAggregateWindowedFunction("", WithAggCount( + NewNodeCount(""))))))), + }, WithLimit( + NewNodeLimitClause("", 100, 0)), WithFrom( + NewNodeFromClause("", []*NodeTableSource{ + NewNodeTableSource("", WithTableName("t")), + }, WithWhere(expr)))))))))))), + }) + + assert.True(t, Equal(n, n)) +} diff --git a/internal/mysqld/planner/sql_statements.go b/internal/mysqld/planner/sql_statements.go index bd9a9ecfb1a5a..ed184a53a6a45 100644 --- a/internal/mysqld/planner/sql_statements.go +++ b/internal/mysqld/planner/sql_statements.go @@ -21,7 +21,7 @@ func (n *NodeSqlStatements) Accept(v Visitor) interface{} { return v.VisitSqlStatements(n) } -func NewNodeSqlStatements(statements []*NodeSqlStatement, text string) *NodeSqlStatements { +func NewNodeSqlStatements(text string, statements []*NodeSqlStatement) *NodeSqlStatements { return &NodeSqlStatements{ baseNode: newBaseNode(text), Statements: statements, diff --git a/internal/mysqld/planner/test_value.go b/internal/mysqld/planner/test_value.go index 83336d46c9e23..282ab5d5cad07 100644 --- a/internal/mysqld/planner/test_value.go +++ b/internal/mysqld/planner/test_value.go @@ -1,9 +1,20 @@ package planner -type TestValue = int +type TestValue int const ( TestValueUnknown TestValue = iota TestValueTrue TestValueFalse ) + +func (t TestValue) String() string { + switch t { + case TestValueTrue: + return "true" + case TestValueFalse: + return "false" + default: + return "unknown" + } +} diff --git a/internal/mysqld/planner/expression_text_restorer.go b/internal/mysqld/planner/visitor_expression_text_restorer.go similarity index 100% rename from internal/mysqld/planner/expression_text_restorer.go rename to internal/mysqld/planner/visitor_expression_text_restorer.go diff --git a/internal/mysqld/planner/visitor_json.go b/internal/mysqld/planner/visitor_json.go new file mode 100644 index 0000000000000..f45f1813f14ad --- /dev/null +++ b/internal/mysqld/planner/visitor_json.go @@ -0,0 +1,373 @@ +package planner + +import "strconv" + +type jsonVisitor struct { +} + +func (v jsonVisitor) VisitDmlStatement(n *NodeDmlStatement) interface{} { + j := map[string]interface{}{} + if n.SelectStatement.IsSome() { + j["dml_statement"] = n.SelectStatement.Unwrap().Accept(v) + } + return j +} + +func (v jsonVisitor) VisitSelectStatement(n *NodeSelectStatement) interface{} { + j := map[string]interface{}{} + if n.SimpleSelect.IsSome() { + j["simple_select"] = n.SimpleSelect.Unwrap().Accept(v) + } + return j +} + +func (v jsonVisitor) VisitSimpleSelect(n *NodeSimpleSelect) interface{} { + j := map[string]interface{}{} + + if n.Query.IsSome() { + j["query"] = n.Query.Unwrap().Accept(v) + } + + if n.LockClause.IsSome() { + j["lock_clause"] = n.LockClause.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitQuerySpecification(n *NodeQuerySpecification) interface{} { + var selectSpecs []interface{} + var selectElements []interface{} + + for _, selectSpec := range n.SelectSpecs { + selectSpecs = append(selectSpecs, selectSpec.Accept(v)) + } + + for _, selectElement := range n.SelectElements { + selectElements = append(selectElements, selectElement.Accept(v)) + } + + j := map[string]interface{}{ + "select_specs": selectSpecs, + "select_elements": selectElements, + } + + if n.From.IsSome() { + j["from_clause"] = n.From.Unwrap().Accept(v) + } + + if n.Limit.IsSome() { + j["limit_clause"] = n.Limit.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitLockClause(n *NodeLockClause) interface{} { + // leaf node. + return n.Option.String() +} + +func (v jsonVisitor) VisitSelectSpec(n *NodeSelectSpec) interface{} { + // leaf node. + return "" +} + +func (v jsonVisitor) VisitSelectElement(n *NodeSelectElement) interface{} { + j := map[string]interface{}{} + + if n.Star.IsSome() { + j["star"] = n.Star.Unwrap().Accept(v) + } + + if n.FullColumnName.IsSome() { + j["full_column_name"] = n.FullColumnName.Unwrap().Accept(v) + } + + if n.FunctionCall.IsSome() { + j["function_call"] = n.FunctionCall.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitFromClause(n *NodeFromClause) interface{} { + var tableSources []interface{} + + for _, tableSource := range n.TableSources { + tableSources = append(tableSources, tableSource.Accept(v)) + } + + j := map[string]interface{}{ + "table_sources": tableSources, + } + + if n.Where.IsSome() { + j["where"] = n.Where.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitLimitClause(n *NodeLimitClause) interface{} { + // leaf node. + j := map[string]interface{}{ + "limit": n.Limit, + "offset": n.Offset, + } + return j +} + +func (v jsonVisitor) VisitSelectElementStar(n *NodeSelectElementStar) interface{} { + // leaf node. + return "*" +} + +func (v jsonVisitor) VisitFullColumnName(n *NodeFullColumnName) interface{} { + // leaf node. + + j := map[string]interface{}{ + "name": n.Name, + } + + if n.Alias.IsSome() { + j["alias"] = n.Alias.Unwrap() + } + + return j +} + +func (v jsonVisitor) VisitFunctionCall(n *NodeFunctionCall) interface{} { + j := map[string]interface{}{} + + if n.Agg.IsSome() { + j["agg"] = n.Agg.Unwrap().Accept(v) + } + + if n.Alias.IsSome() { + j["alias"] = n.Alias.Unwrap() + } + + return j +} + +func (v jsonVisitor) VisitAggregateWindowedFunction(n *NodeAggregateWindowedFunction) interface{} { + j := map[string]interface{}{} + + if n.AggCount.IsSome() { + j["agg_count"] = n.AggCount.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitCount(n *NodeCount) interface{} { + // leaf node. + return "count" +} + +func (v jsonVisitor) VisitTableSource(n *NodeTableSource) interface{} { + j := map[string]interface{}{} + + if n.TableName.IsSome() { + j["name"] = n.TableName.Unwrap() + } + + return j +} + +func (v jsonVisitor) VisitExpression(n *NodeExpression) interface{} { + + j := map[string]interface{}{} + + if n.NotExpr.IsSome() { + j["not_expr"] = n.NotExpr.Unwrap().Accept(v) + } + + if n.LogicalExpr.IsSome() { + j["logical_expr"] = n.LogicalExpr.Unwrap().Accept(v) + } + + if n.IsExpr.IsSome() { + j["is_expr"] = n.IsExpr.Unwrap().Accept(v) + } + + if n.Predicate.IsSome() { + j["predicate"] = n.Predicate.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitExpressions(n *NodeExpressions) interface{} { + j := map[string]interface{}{} + + var expressions []interface{} + + for _, expression := range n.Expressions { + expressions = append(expressions, expression.Accept(v)) + } + + j["expressions"] = expressions + + return j +} + +func (v jsonVisitor) VisitNotExpression(n *NodeNotExpression) interface{} { + j := map[string]interface{}{} + + j["expression"] = n.Expression.Accept(v) + + return j +} + +func (v jsonVisitor) VisitLogicalExpression(n *NodeLogicalExpression) interface{} { + j := map[string]interface{}{} + + j["op"] = n.Op.String() + + j["left"] = n.Left.Accept(v) + + j["right"] = n.Right.Accept(v) + + return j +} + +func (v jsonVisitor) VisitIsExpression(n *NodeIsExpression) interface{} { + j := map[string]interface{}{} + + j["predicate"] = n.Predicate.Accept(v) + + j["op"] = n.Op.String() + + j["test_value"] = n.TV.String() + + return j +} + +func (v jsonVisitor) VisitPredicate(n *NodePredicate) interface{} { + j := map[string]interface{}{} + + if n.InPredicate.IsSome() { + j["in_predicate"] = n.InPredicate.Unwrap().Accept(v) + } + + if n.BinaryComparisonPredicate.IsSome() { + j["binary_comparison_predicate"] = n.BinaryComparisonPredicate.Unwrap().Accept(v) + } + + if n.ExpressionAtomPredicate.IsSome() { + j["expression_atom_predicate"] = n.ExpressionAtomPredicate.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitInPredicate(n *NodeInPredicate) interface{} { + j := map[string]interface{}{} + + j["predicate"] = n.Predicate.Accept(v) + + j["op"] = n.Op.String() + + j["expressions"] = n.Expressions.Accept(v) + + return j +} + +func (v jsonVisitor) VisitBinaryComparisonPredicate(n *NodeBinaryComparisonPredicate) interface{} { + j := map[string]interface{}{} + + j["left"] = n.Left.Accept(v) + + j["op"] = n.Op.String() + + j["right"] = n.Right.Accept(v) + + return j +} + +func (v jsonVisitor) VisitExpressionAtomPredicate(n *NodeExpressionAtomPredicate) interface{} { + j := map[string]interface{}{} + + j["expression_atom"] = n.ExpressionAtom.Accept(v) + + return j +} + +func (v jsonVisitor) VisitExpressionAtom(n *NodeExpressionAtom) interface{} { + + j := map[string]interface{}{} + + if n.Constant.IsSome() { + j["constant"] = n.Constant.Unwrap().Accept(v) + } + + if n.FullColumnName.IsSome() { + j["full_column_name"] = n.FullColumnName.Unwrap().Accept(v) + } + + if n.UnaryExpr.IsSome() { + j["unary_expr"] = n.UnaryExpr.Unwrap().Accept(v) + } + + if n.NestedExpr.IsSome() { + j["nested_expr"] = n.NestedExpr.Unwrap().Accept(v) + } + + return j +} + +func (v jsonVisitor) VisitUnaryExpressionAtom(n *NodeUnaryExpressionAtom) interface{} { + + j := map[string]interface{}{} + + j["op"] = n.Op.String() + + j["expr"] = n.Expr.Accept(v) + + return j +} + +func (v jsonVisitor) VisitNestedExpressionAtom(n *NodeNestedExpressionAtom) interface{} { + + j := map[string]interface{}{} + + var expressions []interface{} + + for _, expression := range n.Expressions { + expressions = append(expressions, expression.Accept(v)) + } + + j["expressions"] = expressions + + return j +} + +func (v jsonVisitor) VisitConstant(n *NodeConstant) interface{} { + // leaf node. + + j := map[string]interface{}{} + + if n.StringLiteral.IsSome() { + j["string_literal"] = n.StringLiteral.Unwrap() + } + + if n.DecimalLiteral.IsSome() { + j["decimal_literal"] = strconv.FormatInt(n.DecimalLiteral.Unwrap(), 10) + } + + if n.BooleanLiteral.IsSome() { + j["boolean_literal"] = strconv.FormatBool(n.BooleanLiteral.Unwrap()) + } + + if n.RealLiteral.IsSome() { + j["real_literal"] = strconv.FormatFloat(n.RealLiteral.Unwrap(), 'f', -1, 64) + } + + return j +} + +func NewJSONVisitor() Visitor { + return &jsonVisitor{} +} diff --git a/internal/mysqld/planner/visitor_json_wrong_var_naming.go b/internal/mysqld/planner/visitor_json_wrong_var_naming.go new file mode 100644 index 0000000000000..c67b09d072a62 --- /dev/null +++ b/internal/mysqld/planner/visitor_json_wrong_var_naming.go @@ -0,0 +1,23 @@ +package planner + +func (v jsonVisitor) VisitSqlStatements(n *NodeSqlStatements) interface{} { + var stmts []interface{} + for _, stmt := range n.Statements { + stmts = append(stmts, stmt.Accept(v)) + } + j := map[string]interface{}{ + "sql_statements": stmts, + } + return j +} + +func (v jsonVisitor) VisitSqlStatement(n *NodeSqlStatement) interface{} { + var r interface{} + if n.DmlStatement.IsSome() { + r = n.DmlStatement.Unwrap().Accept(v) + } + j := map[string]interface{}{ + "sql_statement": r, + } + return j +}