diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 58c83eb7a7cd8..d327b1a694c9e 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -474,6 +474,10 @@ func (c *ifFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) } retTp := InferType4ControlFuncs(args[1].GetType(), args[2].GetType()) evalTps := retTp.EvalType() + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, err + } bf := newBaseBuiltinFuncWithTp(ctx, args, evalTps, types.ETInt, evalTps, evalTps) retTp.Flag |= bf.tp.Flag bf.tp = retTp diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 00a4aef065412..cef220eb5d91b 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -81,6 +81,12 @@ func (s *testEvaluatorSuite) TestIf(c *C) { {types.Duration{Duration: time.Duration(0)}, 1, 2, 2}, {types.NewDecFromStringForTest("1.2"), 1, 2, 1}, {jsonInt.GetMysqlJSON(), 1, 2, 1}, + {0.1, 1, 2, 1}, + {0.0, 1, 2, 2}, + {types.NewDecFromStringForTest("0.1"), 1, 2, 1}, + {types.NewDecFromStringForTest("0.0"), 1, 2, 2}, + {"0.1", 1, 2, 1}, + {"0.0", 1, 2, 2}, } fc := funcs[ast.If]