Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix incorrect result of logical operators #12173

Merged
merged 3 commits into from
Oct 14, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,8 @@ var funcs = map[string]functionClass{
ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
72 changes: 60 additions & 12 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicAndSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd)
Expand Down Expand Up @@ -108,6 +117,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, err
}
args[0], err = wrapWithIsTrue(ctx, true, args[0])
if err != nil {
return nil, errors.Trace(err)
}
args[1], err = wrapWithIsTrue(ctx, true, args[1])
if err != nil {
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
bf.tp.Flen = 1
sig := &builtinLogicOrSig{bf}
Expand Down Expand Up @@ -155,6 +173,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt)
sig := &builtinLogicXorSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_LogicalXor)
Expand Down Expand Up @@ -378,6 +397,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass
op opcode.Op

// keepNull indicates how this function treats a null input parameter.
// If keepNull is true and the input parameter is null, the function will return null.
// If keepNull is false, the null input parameter will be cast to 0.
keepNull bool
}

func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand All @@ -400,27 +424,27 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
case opcode.IsTruth:
switch argTp {
case types.ETReal:
sig = &builtinRealIsTrueSig{bf}
sig = &builtinRealIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue)
case types.ETDecimal:
sig = &builtinDecimalIsTrueSig{bf}
sig = &builtinDecimalIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue)
case types.ETInt:
sig = &builtinIntIsTrueSig{bf}
sig = &builtinIntIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
}
case opcode.IsFalsity:
switch argTp {
case types.ETReal:
sig = &builtinRealIsFalseSig{bf}
sig = &builtinRealIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse)
case types.ETDecimal:
sig = &builtinDecimalIsFalseSig{bf}
sig = &builtinDecimalIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse)
case types.ETInt:
sig = &builtinIntIsFalseSig{bf}
sig = &builtinIntIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
Expand All @@ -431,10 +455,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []

type builtinRealIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsTrueSig) Clone() builtinFunc {
newSig := &builtinRealIsTrueSig{}
newSig := &builtinRealIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -444,6 +469,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -452,10 +480,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsTrueSig) Clone() builtinFunc {
newSig := &builtinDecimalIsTrueSig{}
newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -465,6 +494,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input.IsZero() {
return 0, false, nil
}
Expand All @@ -473,10 +505,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsTrueSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsTrueSig) Clone() builtinFunc {
newSig := &builtinIntIsTrueSig{}
newSig := &builtinIntIsTrueSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -486,6 +519,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input == 0 {
return 0, false, nil
}
Expand All @@ -494,10 +530,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinRealIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinRealIsFalseSig) Clone() builtinFunc {
newSig := &builtinRealIsFalseSig{}
newSig := &builtinRealIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -507,6 +544,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand All @@ -515,10 +555,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinDecimalIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinDecimalIsFalseSig) Clone() builtinFunc {
newSig := &builtinDecimalIsFalseSig{}
newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -528,6 +569,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || !input.IsZero() {
return 0, false, nil
}
Expand All @@ -536,10 +580,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {

type builtinIntIsFalseSig struct {
baseBuiltinFunc
keepNull bool
}

func (b *builtinIntIsFalseSig) Clone() builtinFunc {
newSig := &builtinIntIsFalseSig{}
newSig := &builtinIntIsFalseSig{keepNull: b.keepNull}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}
Expand All @@ -549,6 +594,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) {
if err != nil {
return 0, true, err
}
if b.keepNull && isNull {
return 0, true, nil
}
if isNull || input != 0 {
return 0, false, nil
}
Expand Down
89 changes: 89 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) {
{[]interface{}{0, 1}, 0, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 0, false, false},
{[]interface{}{"1a", "0"}, 0, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
{[]interface{}{0, nil}, 0, false, false},
{[]interface{}{nil, 0}, 0, false, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.001, 0}, 0, false, false},
{[]interface{}{0.001, 1}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, false, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -300,11 +310,25 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) {
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 1, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 1, false, false},
{[]interface{}{"0.0a", 0}, 0, false, false},
{[]interface{}{"0.0001a", 0}, 1, false, false},
{[]interface{}{1, nil}, 1, false, false},
{[]interface{}{nil, 1}, 1, false, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{0.000, 0}, 0, false, false},
{[]interface{}{0.001, 0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}
Expand Down Expand Up @@ -559,3 +583,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
}
}

func (s *testEvaluatorSuite) TestLogicXor(c *C) {
defer testleak.AfterTest(c)()

sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

cases := []struct {
args []interface{}
expected int64
isNil bool
getErr bool
}{
{[]interface{}{1, 1}, 0, false, false},
{[]interface{}{1, 0}, 1, false, false},
{[]interface{}{0, 1}, 1, false, false},
{[]interface{}{0, 0}, 0, false, false},
{[]interface{}{2, -1}, 0, false, false},
{[]interface{}{"a", "0"}, 0, false, false},
{[]interface{}{"a", "1"}, 1, false, false},
{[]interface{}{"1a", "0"}, 1, false, false},
{[]interface{}{"1a", "1"}, 0, false, false},
{[]interface{}{0, nil}, 0, true, false},
{[]interface{}{nil, 0}, 0, true, false},
{[]interface{}{nil, 1}, 0, true, false},
{[]interface{}{0.5000, 0.4999}, 1, false, false},
{[]interface{}{0.5000, 1.0}, 0, false, false},
{[]interface{}{0.4999, 1.0}, 1, false, false},
{[]interface{}{nil, 0.000}, 0, true, false},
{[]interface{}{nil, 0.001}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false},
{[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false},
{[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false},

{[]interface{}{errors.New("must error"), 1}, 0, false, true},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
if t.getErr {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
} else {
c.Assert(d.GetInt64(), Equals, t.expected)
}
}
}

// Test incorrect parameter count.
_, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero)
c.Assert(err, NotNil)

_, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero})
c.Assert(err, IsNil)
}
12 changes: 6 additions & 6 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,17 +376,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
f = &builtinCaseWhenIntSig{base}

case tipb.ScalarFuncSig_IntIsFalse:
f = &builtinIntIsFalseSig{base}
f = &builtinIntIsFalseSig{base, false}
case tipb.ScalarFuncSig_RealIsFalse:
f = &builtinRealIsFalseSig{base}
f = &builtinRealIsFalseSig{base, false}
case tipb.ScalarFuncSig_DecimalIsFalse:
f = &builtinDecimalIsFalseSig{base}
f = &builtinDecimalIsFalseSig{base, false}
case tipb.ScalarFuncSig_IntIsTrue:
f = &builtinIntIsTrueSig{base}
f = &builtinIntIsTrueSig{base, false}
case tipb.ScalarFuncSig_RealIsTrue:
f = &builtinRealIsTrueSig{base}
f = &builtinRealIsTrueSig{base, false}
case tipb.ScalarFuncSig_DecimalIsTrue:
f = &builtinDecimalIsTrueSig{base}
f = &builtinDecimalIsTrueSig{base, false}

case tipb.ScalarFuncSig_IfNullReal:
f = &builtinIfNullRealSig{base}
Expand Down
Loading