Skip to content

Commit

Permalink
expression: fix wrong result of Not/IsTrue/IsFalse functions (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
eurekaka committed Mar 31, 2020
1 parent 3a5c5d0 commit f7eea74
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 15 deletions.
95 changes: 82 additions & 13 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ var (
_ builtinFunc = &builtinRealIsNullSig{}
_ builtinFunc = &builtinStringIsNullSig{}
_ builtinFunc = &builtinTimeIsNullSig{}
_ builtinFunc = &builtinUnaryNotSig{}
_ builtinFunc = &builtinUnaryNotRealSig{}
_ builtinFunc = &builtinUnaryNotDecimalSig{}
_ builtinFunc = &builtinUnaryNotIntSig{}
)

type logicAndFunctionClass struct {
Expand Down Expand Up @@ -408,8 +410,10 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
}

argTp := args[0].GetType().EvalType()
if argTp != types.ETReal && argTp != types.ETDecimal {
if argTp == types.ETTimestamp || argTp == types.ETDatetime || argTp == types.ETDuration {
argTp = types.ETInt
} else if argTp == types.ETJson || argTp == types.ETString {
argTp = types.ETReal
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp)
Expand All @@ -428,6 +432,8 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
case types.ETInt:
sig = &builtinIntIsTrueSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
}
case opcode.IsFalsity:
switch argTp {
Expand All @@ -440,6 +446,8 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []
case types.ETInt:
sig = &builtinIntIsFalseSig{bf, c.keepNull}
sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
}
}
return sig, nil
Expand Down Expand Up @@ -637,33 +645,94 @@ func (c *unaryNotFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, errors.Trace(err)
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt)
argTp := args[0].GetType().EvalType()
if argTp == types.ETTimestamp || argTp == types.ETDatetime || argTp == types.ETDuration {
argTp = types.ETInt
} else if argTp == types.ETJson || argTp == types.ETString {
argTp = types.ETReal
}

bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTp)
bf.tp.Flen = 1

sig := &builtinUnaryNotSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_UnaryNot)
var sig builtinFunc
switch argTp {
case types.ETReal:
sig = &builtinUnaryNotRealSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_UnaryNotReal)
case types.ETDecimal:
sig = &builtinUnaryNotDecimalSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_UnaryNotDecimal)
case types.ETInt:
sig = &builtinUnaryNotIntSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_UnaryNotInt)
default:
return nil, errors.Errorf("unexpected types.EvalType %v", argTp)
}
return sig, nil
}

type builtinUnaryNotSig struct {
type builtinUnaryNotRealSig struct {
baseBuiltinFunc
}

func (b *builtinUnaryNotRealSig) Clone() builtinFunc {
newSig := &builtinUnaryNotRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinUnaryNotRealSig) evalInt(row chunk.Row) (int64, bool, error) {
arg, isNull, err := b.args[0].EvalReal(b.ctx, row)
if isNull || err != nil {
return 0, true, err
}
if arg == 0 {
return 1, false, nil
}
return 0, false, nil
}

type builtinUnaryNotDecimalSig struct {
baseBuiltinFunc
}

func (b *builtinUnaryNotSig) Clone() builtinFunc {
newSig := &builtinUnaryNotSig{}
func (b *builtinUnaryNotDecimalSig) Clone() builtinFunc {
newSig := &builtinUnaryNotDecimalSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinUnaryNotSig) evalInt(row chunk.Row) (int64, bool, error) {
func (b *builtinUnaryNotDecimalSig) evalInt(row chunk.Row) (int64, bool, error) {
arg, isNull, err := b.args[0].EvalDecimal(b.ctx, row)
if isNull || err != nil {
return 0, true, err
}
if arg.IsZero() {
return 1, false, nil
}
return 0, false, nil
}

type builtinUnaryNotIntSig struct {
baseBuiltinFunc
}

func (b *builtinUnaryNotIntSig) Clone() builtinFunc {
newSig := &builtinUnaryNotIntSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinUnaryNotIntSig) evalInt(row chunk.Row) (int64, bool, error) {
arg, isNull, err := b.args[0].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
return 0, true, err
}
if arg != 0 {
return 0, false, nil
if arg == 0 {
return 1, false, nil
}
return 1, false, nil
return 0, false, nil
}

type unaryMinusFunctionClass struct {
Expand Down
18 changes: 18 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) {
{[]interface{}{123}, 0, false, false},
{[]interface{}{-123}, 0, false, false},
{[]interface{}{"123"}, 0, false, false},
{[]interface{}{float64(0.3)}, 0, false, false},
{[]interface{}{"0.3"}, 0, false, false},
{[]interface{}{types.NewDecFromFloatForTest(0.3)}, 0, false, false},
{[]interface{}{nil}, 0, true, false},

{[]interface{}{errors.New("must error")}, 0, false, true},
Expand Down Expand Up @@ -539,6 +542,21 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{"0.3"},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{float64(0.3)},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{types.NewDecFromFloatForTest(0.3)},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{nil},
isTrue: 0,
Expand Down
8 changes: 6 additions & 2 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,12 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
case tipb.ScalarFuncSig_BitNegSig:
f = &builtinBitNegSig{base}

case tipb.ScalarFuncSig_UnaryNot:
f = &builtinUnaryNotSig{base}
case tipb.ScalarFuncSig_UnaryNotReal:
f = &builtinUnaryNotRealSig{base}
case tipb.ScalarFuncSig_UnaryNotDecimal:
f = &builtinUnaryNotDecimalSig{base}
case tipb.ScalarFuncSig_UnaryNotInt:
f = &builtinUnaryNotIntSig{base}
case tipb.ScalarFuncSig_UnaryMinusInt:
f = &builtinUnaryMinusIntSig{base}
case tipb.ScalarFuncSig_UnaryMinusReal:
Expand Down

0 comments on commit f7eea74

Please sign in to comment.