From 93f81512f84bad176dd767f753709cf9d75e7c92 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Fri, 18 Oct 2019 16:49:21 +0800 Subject: [PATCH] expression: fix incorrect result of logical operators (#12173) (#12811) --- expression/builtin.go | 4 +- expression/builtin_op.go | 73 +++++++++++++++++++++++----- expression/builtin_op_test.go | 90 +++++++++++++++++++++++++++++++++++ expression/distsql_builtin.go | 12 ++--- expression/expression.go | 22 +++++++++ 5 files changed, 181 insertions(+), 20 deletions(-) diff --git a/expression/builtin.go b/expression/builtin.go index e86ec51d85614..da5f0a6509a0b 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -540,8 +540,8 @@ var funcs = map[string]functionClass{ ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}}, ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}}, - ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, - ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, + ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false}, + ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false}, ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}}, ast.Regexp: ®expFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}}, ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}}, diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 3c30642d6b3e0..cf6dc0a71a857 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -17,6 +17,7 @@ import ( "fmt" "math" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/sessionctx" @@ -64,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicAndSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd) @@ -105,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) bf.tp.Flen = 1 sig := &builtinLogicOrSig{bf} @@ -152,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicXorSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalXor) @@ -375,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) { type isTrueOrFalseFunctionClass struct { baseFunctionClass op opcode.Op + + // keepNull indicates how this function treats a null input parameter. + // If keepNull is true and the input parameter is null, the function will return null. + // If keepNull is false, the null input parameter will be cast to 0. + keepNull bool } func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { @@ -395,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case opcode.IsTruth: switch argTp { case types.ETReal: - sig = &builtinRealIsTrueSig{bf} + sig = &builtinRealIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue) case types.ETDecimal: - sig = &builtinDecimalIsTrueSig{bf} + sig = &builtinDecimalIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue) case types.ETInt: - sig = &builtinIntIsTrueSig{bf} + sig = &builtinIntIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue) } case opcode.IsFalsity: switch argTp { case types.ETReal: - sig = &builtinRealIsFalseSig{bf} + sig = &builtinRealIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse) case types.ETDecimal: - sig = &builtinDecimalIsFalseSig{bf} + sig = &builtinDecimalIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse) case types.ETInt: - sig = &builtinIntIsFalseSig{bf} + sig = &builtinIntIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse) } } @@ -422,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] type builtinRealIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsTrueSig) Clone() builtinFunc { - newSig := &builtinRealIsTrueSig{} + newSig := &builtinRealIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -435,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -443,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsTrueSig) Clone() builtinFunc { - newSig := &builtinDecimalIsTrueSig{} + newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -456,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input.IsZero() { return 0, false, nil } @@ -464,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsTrueSig) Clone() builtinFunc { - newSig := &builtinIntIsTrueSig{} + newSig := &builtinIntIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -477,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -485,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinRealIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsFalseSig) Clone() builtinFunc { - newSig := &builtinRealIsFalseSig{} + newSig := &builtinRealIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -498,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } @@ -506,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsFalseSig) Clone() builtinFunc { - newSig := &builtinDecimalIsFalseSig{} + newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -519,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || !input.IsZero() { return 0, false, nil } @@ -527,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsFalseSig) Clone() builtinFunc { - newSig := &builtinIntIsFalseSig{} + newSig := &builtinIntIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -540,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index b2f700cb7ca57..a45d488dfff95 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) { {[]interface{}{0, 1}, 0, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 0, false, false}, + {[]interface{}{"1a", "0"}, 0, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, {[]interface{}{0, nil}, 0, false, false}, {[]interface{}{nil, 0}, 0, false, false}, {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.001, 0}, 0, false, false}, + {[]interface{}{0.001, 1}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, false, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -300,11 +310,26 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) { {[]interface{}{0, 1}, 1, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, + // casting string to real depends on #10498, which will not be cherry-picked. + // {[]interface{}{"0.0a", 0}, 0, false, false}, + // {[]interface{}{"0.0001a", 0}, 1, false, false}, {[]interface{}{1, nil}, 1, false, false}, {[]interface{}{nil, 1}, 1, false, false}, {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{0.000, 0}, 0, false, false}, + {[]interface{}{0.001, 0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -541,3 +566,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) { c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse)) } } + +func (s *testEvaluatorSuite) TestLogicXor(c *C) { + defer testleak.AfterTest(c)() + + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.IgnoreTruncate + defer func() { + sc.IgnoreTruncate = origin + }() + sc.IgnoreTruncate = true + + cases := []struct { + args []interface{} + expected int64 + isNil bool + getErr bool + }{ + {[]interface{}{1, 1}, 0, false, false}, + {[]interface{}{1, 0}, 1, false, false}, + {[]interface{}{0, 1}, 1, false, false}, + {[]interface{}{0, 0}, 0, false, false}, + {[]interface{}{2, -1}, 0, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, + {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, + {[]interface{}{"1a", "1"}, 0, false, false}, + {[]interface{}{0, nil}, 0, true, false}, + {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.5000, 0.4999}, 1, false, false}, + {[]interface{}{0.5000, 1.0}, 0, false, false}, + {[]interface{}{0.4999, 1.0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, + + {[]interface{}{errors.New("must error"), 1}, 0, false, true}, + } + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...) + c.Assert(err, IsNil) + d, err := f.Eval(chunk.Row{}) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if t.isNil { + c.Assert(d.Kind(), Equals, types.KindNull) + } else { + c.Assert(d.GetInt64(), Equals, t.expected) + } + } + } + + // Test incorrect parameter count. + _, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero) + c.Assert(err, NotNil) + + _, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero}) + c.Assert(err, IsNil) +} diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 127e6baa95c78..749b46b467dc5 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinCaseWhenIntSig{base} case tipb.ScalarFuncSig_IntIsFalse: - f = &builtinIntIsFalseSig{base} + f = &builtinIntIsFalseSig{base, false} case tipb.ScalarFuncSig_RealIsFalse: - f = &builtinRealIsFalseSig{base} + f = &builtinRealIsFalseSig{base, false} case tipb.ScalarFuncSig_DecimalIsFalse: - f = &builtinDecimalIsFalseSig{base} + f = &builtinDecimalIsFalseSig{base, false} case tipb.ScalarFuncSig_IntIsTrue: - f = &builtinIntIsTrueSig{base} + f = &builtinIntIsTrueSig{base, false} case tipb.ScalarFuncSig_RealIsTrue: - f = &builtinRealIsTrueSig{base} + f = &builtinRealIsTrueSig{base, false} case tipb.ScalarFuncSig_DecimalIsTrue: - f = &builtinDecimalIsTrueSig{base} + f = &builtinDecimalIsTrueSig{base, false} case tipb.ScalarFuncSig_IfNullReal: f = &builtinIfNullRealSig{base} diff --git a/expression/expression.go b/expression/expression.go index 30520557d74a9..7f5f65f3f6f58 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/opcode" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -376,3 +377,24 @@ func IsBinaryLiteral(expr Expression) bool { con, ok := expr.(*Constant) return ok && con.Value.Kind() == types.KindBinaryLiteral } + +// wrapWithIsTrue wraps `arg` with istrue function if the return type of expr is not +// type int, otherwise, returns `arg` directly. +// The `keepNull` controls what the istrue function will return when `arg` is null: +// 1. keepNull is true and arg is null, the istrue function returns null. +// 2. keepNull is false and arg is null, the istrue function returns 0. +func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression) (Expression, error) { + if arg.GetType().EvalType() == types.ETInt { + return arg, nil + } + fc := &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, keepNull} + f, err := fc.getFunction(ctx, []Expression{arg}) + if err != nil { + return nil, err + } + return &ScalarFunction{ + FuncName: model.NewCIStr(fmt.Sprintf("sig_%T", f)), + Function: f, + RetType: f.getRetTp(), + }, nil +}