Skip to content

Commit

Permalink
expression: Fix incorrect result of logical operators (pingcap#11199)
Browse files Browse the repository at this point in the history
  • Loading branch information
sduzh authored and SunRunAway committed Oct 18, 2019
1 parent cf7f6f1 commit c79440f
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 20 deletions.
4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ var funcs = map[string]functionClass{
ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
72 changes: 60 additions & 12 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, errors.Trace(err)
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicAndSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd)
Expand Down Expand Up @@ -106,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, errors.Trace(err)
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
bf.tp.Flen = 1
sig := &builtinLogicOrSig{bf}
Expand Down Expand Up @@ -153,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicXorSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalXor)
Expand Down Expand Up @@ -376,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass
op opcode.Op

// keepNull indicates how this function treats a null input parameter.
// If keepNull is true and the input parameter is null, the function will return null.
// If keepNull is false, the null input parameter will be cast to 0.
keepNull bool
}

func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand All @@ -396,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
case opcode.IsTruth:
switch argTp {
case types.ETReal:
sig = &builtinRealIsTrueSig{bf}
sig = &builtinRealIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
case types.ETDecimal:
sig = &builtinDecimalIsTrueSig{bf}
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
case types.ETInt:
sig = &builtinIntIsTrueSig{bf}
sig = &builtinIntIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
}
case opcode.IsFalsity:
switch argTp {
case types.ETReal:
sig = &builtinRealIsFalseSig{bf}
sig = &builtinRealIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
case types.ETDecimal:
sig = &builtinDecimalIsFalseSig{bf}
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
case types.ETInt:
sig = &builtinIntIsFalseSig{bf}
sig = &builtinIntIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
}
}
Expand All @@ -423,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []

type builtinRealIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsTrueSig) Clone() builtinFunc {
newSig := &builtinRealIsTrueSig{}
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -436,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -444,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
newSig := &builtinDecimalIsTrueSig{}
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -457,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input.IsZero() {
return 0, false, nil
}
Expand All @@ -465,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsTrueSig) Clone() builtinFunc {
newSig := &builtinIntIsTrueSig{}
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -478,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -486,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinRealIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsFalseSig) Clone() builtinFunc {
newSig := &builtinRealIsFalseSig{}
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -499,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand All @@ -507,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
newSig := &builtinDecimalIsFalseSig{}
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -520,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || !input.IsZero() {
return 0, false, nil
}
Expand All @@ -528,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsFalseSig) Clone() builtinFunc {
newSig := &builtinIntIsFalseSig{}
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -541,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, errors.Trace(err)
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand Down
89 changes: 89 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) {
{[]interface{}{0, 1}, 0, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 0, false, false},
{[]interface{}{"1a", "0"}, 0, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
{[]interface{}{0, nil}, 0, false, false},
{[]interface{}{nil, 0}, 0, false, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.001, 0}, 0, false, false},
{[]interface{}{0.001, 1}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, false, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -300,11 +310,25 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) {
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
{[]interface{}{"0.0a", 0}, 0, false, false},
{[]interface{}{"0.0001a", 0}, 1, false, false},
{[]interface{}{1, nil}, 1, false, false},
{[]interface{}{nil, 1}, 1, false, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{0.000, 0}, 0, false, false},
{[]interface{}{0.001, 0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -541,3 +565,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
}
}

func (s *testEvaluatorSuite) TestLogicXor(c *C) {
defer testleak.AfterTest(c)()

sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

cases := []struct {
args []interface{}
expected int64
isNil bool
getErr bool
}{
{[]interface{}{1, 1}, 0, false, false},
{[]interface{}{1, 0}, 1, false, false},
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 0, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 0, false, false},
{[]interface{}{0, nil}, 0, true, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.5000, 0.4999}, 1, false, false},
{[]interface{}{0.5000, 1.0}, 0, false, false},
{[]interface{}{0.4999, 1.0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
if t.getErr {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
} else {
c.Assert(d.GetInt64(), Equals, t.expected)
}
}
}

// Test incorrect parameter count.
_, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero)
c.Assert(err, NotNil)

_, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero})
c.Assert(err, IsNil)
}
12 changes: 6 additions & 6 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
f = &builtinCaseWhenIntSig{base}

case tipb.ScalarFuncSig_IntIsFalse:
f = &builtinIntIsFalseSig{base}
f = &builtinIntIsFalseSig{base, false}
case tipb.ScalarFuncSig_RealIsFalse:
f = &builtinRealIsFalseSig{base}
f = &builtinRealIsFalseSig{base, false}
case tipb.ScalarFuncSig_DecimalIsFalse:
f = &builtinDecimalIsFalseSig{base}
f = &builtinDecimalIsFalseSig{base, false}
case tipb.ScalarFuncSig_IntIsTrue:
f = &builtinIntIsTrueSig{base}
f = &builtinIntIsTrueSig{base, false}
case tipb.ScalarFuncSig_RealIsTrue:
f = &builtinRealIsTrueSig{base}
f = &builtinRealIsTrueSig{base, false}
case tipb.ScalarFuncSig_DecimalIsTrue:
f = &builtinDecimalIsTrueSig{base}
f = &builtinDecimalIsTrueSig{base, false}

case tipb.ScalarFuncSig_IfNullReal:
f = &builtinIfNullRealSig{base}
Expand Down
Loading

0 comments on commit c79440f

Please sign in to comment.