Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: reimplement DEFAULT function to make the behavior consistent with MySQL when looking up the corresponding column #19709

Merged
merged 22 commits into from
Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func FindFieldName(names types.NameSlice, astCol *ast.ColumnName) (int, error) {
if idx == -1 {
idx = i
} else {
return -1, errNonUniq.GenWithStackByArgs(name.String(), "field list")
return -1, errNonUniq.GenWithStackByArgs(astCol.String(), "field list")
}
}
}
Expand Down
32 changes: 22 additions & 10 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,25 +1653,38 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
}

func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
stkLen := len(er.ctxStack)
name := er.ctxNameStk[stkLen-1]
switch er.ctxStack[stkLen-1].(type) {
case *expression.Column:
case *expression.CorrelatedColumn:
default:
idx, err := expression.FindFieldName(er.names, v.Name)
var name *types.FieldName
// Here we will find the corresponding column for default function. At the same time, we need to consider the issue
// of subquery and name space.
// For example, we have two tables t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following SQL:
// select a from t1 where a > (select default(a) from t2)
// Refer to the behavior of MySQL, we need to find column a in table t2. If table t2 does not have column a, then find it
// in table t1. If there are none, return an error message.
// Based on the above description, we need to look in er.b.allNames from back to front.
for i := len(er.b.allNames) - 1; i >= 0; i-- {
idx, err := expression.FindFieldName(er.b.allNames[i], v.Name)
if err != nil {
er.err = err
return
}
if er.err != nil {
if idx >= 0 {
name = er.b.allNames[i][idx]
break
}
}
if name == nil {
idx, err := expression.FindFieldName(er.names, v.Name)
if err != nil {
er.err = err
return
}
if idx < 0 {
er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field_list")
er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field list")
return
}
name = er.names[idx]
}

dbName := name.DBName
if dbName.O == "" {
// if database name is not specified, use current database name
Expand Down Expand Up @@ -1719,7 +1732,6 @@ func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
if er.err != nil {
return
}
er.ctxStackPop(1)
er.ctxStackAppend(val, types.EmptyName)
}

Expand Down
17 changes: 16 additions & 1 deletion planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) {
tk.MustExec("create table t2(a varchar(10), b varchar(10))")
tk.MustExec("insert into t2 values ('1', '1')")
err = tk.ExecToErr("select default(a) from t1, t2")
c.Assert(err.Error(), Equals, "[planner:1052]Column 'a' in field list is ambiguous")
c.Assert(err.Error(), Equals, "[expression:1052]Column 'a' in field list is ambiguous")
tk.MustQuery("select default(t1.a) from t1, t2").Check(testkit.Rows("def"))

tk.MustExec(`create table t3(
Expand Down Expand Up @@ -130,6 +130,21 @@ func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) {

tk.MustExec("update t1 set c = c + default(c)")
tk.MustQuery("select c from t1").Check(testkit.Rows("11"))

tk.MustExec("create table t6(a int default -1, b int)")
tk.MustExec(`insert into t6 values (0, 0), (1, 1), (2, 2)`)
tk.MustExec("create table t7(a int default 1, b int)")
tk.MustExec(`insert into t7 values (0, 0), (1, 1), (2, 2)`)

tk.MustQuery(`select a from t6 where a > (select default(a) from t7 where t6.a = t7.a)`).Check(testkit.Rows("2"))
tk.MustQuery(`select a, default(a) from t6 where a > (select default(a) from t7 where t6.a = t7.a)`).Check(testkit.Rows("2 -1"))

tk.MustExec("create table t8(a int default 1, b int default -1)")
tk.MustExec(`insert into t8 values (0, 0), (1, 1)`)

tk.MustQuery(`select a, a from t8 order by default(a)`).Check(testkit.Rows("0 0", "1 1"))
tk.MustQuery(`select a from t8 order by default(b)`).Check(testkit.Rows("0", "1"))
tk.MustQuery(`select a from t8 order by default(b) * a`).Check(testkit.Rows("1", "0"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please add some test that is invalid in resolving column names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

func (s *testExpressionRewriterSuite) TestCompareSubquery(c *C) {
Expand Down
9 changes: 9 additions & 0 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,15 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p L
return nil, err
}

// b.allNames will be used in evalDefaultExpr(). Default function is special because it needs to find the
// corresponding column name, but does not need the value in the column.
// For example, select a from t order by default(b), the column b will not be in select fields. Also because
// buildSort is after buildProjection, so we need get OutputNames before BuildProjection and store in allNames.
// Otherwise, we will get select fields instead of all OutputNames, so that we can't find the column b in the
// above example.
b.allNames = append(b.allNames, p.OutputNames())
defer func() { b.allNames = b.allNames[:len(b.allNames)-1] }()

if sel.Where != nil {
p, err = b.buildSelection(ctx, p, sel.Where, nil)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ type PlanBuilder struct {
partitionedTable []table.PartitionedTable
// CreateView needs this information to check whether exists nested view.
underlyingViewNames set.StringSet

// evalDefaultExpr needs this information to find the corresponding column.
// It stores the OutputNames before buildProjection.
allNames [][]*types.FieldName
}

type handleColHelper struct {
Expand Down