Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix incorrect result of logical operators (#12173) #12811

4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ var funcs = map[string]functionClass{
ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
73 changes: 61 additions & 12 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"fmt"
"math"

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/opcode"
"github.com/pingcap/tidb/sessionctx"
Expand Down Expand Up @@ -64,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicAndSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd)
Expand Down Expand Up @@ -105,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
bf.tp.Flen = 1
sig := &builtinLogicOrSig{bf}
Expand Down Expand Up @@ -152,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicXorSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalXor)
Expand Down Expand Up @@ -375,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass
op opcode.Op

// keepNull indicates how this function treats a null input parameter.
// If keepNull is true and the input parameter is null, the function will return null.
// If keepNull is false, the null input parameter will be cast to 0.
keepNull bool
}

func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand All @@ -395,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
case opcode.IsTruth:
switch argTp {
case types.ETReal:
sig = &builtinRealIsTrueSig{bf}
sig = &builtinRealIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
case types.ETDecimal:
sig = &builtinDecimalIsTrueSig{bf}
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
case types.ETInt:
sig = &builtinIntIsTrueSig{bf}
sig = &builtinIntIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
}
case opcode.IsFalsity:
switch argTp {
case types.ETReal:
sig = &builtinRealIsFalseSig{bf}
sig = &builtinRealIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
case types.ETDecimal:
sig = &builtinDecimalIsFalseSig{bf}
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
case types.ETInt:
sig = &builtinIntIsFalseSig{bf}
sig = &builtinIntIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
}
}
Expand All @@ -422,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []

type builtinRealIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsTrueSig) Clone() builtinFunc {
newSig := &builtinRealIsTrueSig{}
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -435,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -443,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
newSig := &builtinDecimalIsTrueSig{}
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -456,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input.IsZero() {
return 0, false, nil
}
Expand All @@ -464,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsTrueSig) Clone() builtinFunc {
newSig := &builtinIntIsTrueSig{}
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -477,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -485,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinRealIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsFalseSig) Clone() builtinFunc {
newSig := &builtinRealIsFalseSig{}
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -498,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand All @@ -506,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
newSig := &builtinDecimalIsFalseSig{}
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -519,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || !input.IsZero() {
return 0, false, nil
}
Expand All @@ -527,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsFalseSig) Clone() builtinFunc {
newSig := &builtinIntIsFalseSig{}
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -540,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand Down
90 changes: 90 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) {
{[]interface{}{0, 1}, 0, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 0, false, false},
{[]interface{}{"1a", "0"}, 0, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
{[]interface{}{0, nil}, 0, false, false},
{[]interface{}{nil, 0}, 0, false, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.001, 0}, 0, false, false},
{[]interface{}{0.001, 1}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, false, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -300,11 +310,26 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) {
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
// casting string to real depends on #10498, which will not be cherry-picked.
// {[]interface{}{"0.0a", 0}, 0, false, false},
// {[]interface{}{"0.0001a", 0}, 1, false, false},
{[]interface{}{1, nil}, 1, false, false},
{[]interface{}{nil, 1}, 1, false, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{0.000, 0}, 0, false, false},
{[]interface{}{0.001, 0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -541,3 +566,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
}
}

func (s *testEvaluatorSuite) TestLogicXor(c *C) {
defer testleak.AfterTest(c)()

sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

cases := []struct {
args []interface{}
expected int64
isNil bool
getErr bool
}{
{[]interface{}{1, 1}, 0, false, false},
{[]interface{}{1, 0}, 1, false, false},
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 0, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 0, false, false},
{[]interface{}{0, nil}, 0, true, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.5000, 0.4999}, 1, false, false},
{[]interface{}{0.5000, 1.0}, 0, false, false},
{[]interface{}{0.4999, 1.0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
if t.getErr {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
} else {
c.Assert(d.GetInt64(), Equals, t.expected)
}
}
}

// Test incorrect parameter count.
_, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero)
c.Assert(err, NotNil)

_, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero})
c.Assert(err, IsNil)
}
12 changes: 6 additions & 6 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
f = &builtinCaseWhenIntSig{base}

case tipb.ScalarFuncSig_IntIsFalse:
f = &builtinIntIsFalseSig{base}
f = &builtinIntIsFalseSig{base, false}
case tipb.ScalarFuncSig_RealIsFalse:
f = &builtinRealIsFalseSig{base}
f = &builtinRealIsFalseSig{base, false}
case tipb.ScalarFuncSig_DecimalIsFalse:
f = &builtinDecimalIsFalseSig{base}
f = &builtinDecimalIsFalseSig{base, false}
case tipb.ScalarFuncSig_IntIsTrue:
f = &builtinIntIsTrueSig{base}
f = &builtinIntIsTrueSig{base, false}
case tipb.ScalarFuncSig_RealIsTrue:
f = &builtinRealIsTrueSig{base}
f = &builtinRealIsTrueSig{base, false}
case tipb.ScalarFuncSig_DecimalIsTrue:
f = &builtinDecimalIsTrueSig{base}
f = &builtinDecimalIsTrueSig{base, false}

case tipb.ScalarFuncSig_IfNullReal:
f = &builtinIfNullRealSig{base}
Expand Down
Loading