Skip to content

Commit

Permalink
expression: refine built-in func truncate to support uint arg (#8000) (
Browse files Browse the repository at this point in the history
  • Loading branch information
yu34po authored and zz-jason committed Oct 26, 2018
1 parent 5707a9b commit 75192d7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
44 changes: 41 additions & 3 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ var (
_ builtinFunc = &builtinTruncateIntSig{}
_ builtinFunc = &builtinTruncateRealSig{}
_ builtinFunc = &builtinTruncateDecimalSig{}
_ builtinFunc = &builtinTruncateUintSig{}
)

type absFunctionClass struct {
Expand Down Expand Up @@ -1737,7 +1738,11 @@ func (c *truncateFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
var sig builtinFunc
switch argTp {
case types.ETInt:
sig = &builtinTruncateIntSig{bf}
if mysql.HasUnsignedFlag(args[0].GetType().Flag) {
sig = &builtinTruncateUintSig{bf}
} else {
sig = &builtinTruncateIntSig{bf}
}
case types.ETReal:
sig = &builtinTruncateRealSig{bf}
case types.ETDecimal:
Expand Down Expand Up @@ -1826,6 +1831,39 @@ func (b *builtinTruncateIntSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, isNull, errors.Trace(err)
}

floatX := float64(x)
return int64(types.Truncate(floatX, int(d))), false, nil
if d >= 0 {
return x, false, nil
}
shift := int64(math.Pow10(int(-d)))
return x / shift * shift, false, nil
}

func (b *builtinTruncateUintSig) Clone() builtinFunc {
newSig := &builtinTruncateUintSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

type builtinTruncateUintSig struct {
baseBuiltinFunc
}

// evalInt evals a TRUNCATE(X,D).
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_truncate
func (b *builtinTruncateUintSig) evalInt(row chunk.Row) (int64, bool, error) {
x, isNull, err := b.args[0].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}
uintx := uint64(x)

d, isNull, err := b.args[1].EvalInt(b.ctx, row)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}
if d >= 0 {
return x, false, nil
}
shift := uint64(math.Pow10(int(-d)))
return int64(uintx / shift * shift), false, nil
}
3 changes: 3 additions & 0 deletions expression/builtin_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ func (s *testEvaluatorSuite) TestTruncate(c *C) {
{[]interface{}{newDec("23.298"), -100}, newDec("0")},
{[]interface{}{newDec("23.298"), 100}, newDec("23.298")},
{[]interface{}{nil, 2}, nil},
{[]interface{}{uint64(9223372036854775808), -10}, 9223372030000000000},
{[]interface{}{9223372036854775807, -7}, 9223372036850000000},
{[]interface{}{uint64(18446744073709551615), -10}, uint64(18446744070000000000)},
}

Dtbl := tblToDtbl(tbl)
Expand Down
2 changes: 2 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) {
result.Check(testkit.Rows("100 123 123 120"))
result = tk.MustQuery("SELECT truncate(123.456, -2), truncate(123.456, 2), truncate(123.456, 1), truncate(123.456, 3), truncate(1.23, 100), truncate(123456E-3, 2);")
result.Check(testkit.Rows("100 123.45 123.4 123.456 1.230000000000000000000000000000 123.45"))
result = tk.MustQuery("SELECT truncate(9223372036854775807, -7), truncate(9223372036854775808, -10), truncate(cast(-1 as unsigned), -10);")
result.Check(testkit.Rows("9223372036850000000 9223372030000000000 18446744070000000000"))

tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a date, b datetime, c timestamp, d varchar(20));`)
Expand Down

0 comments on commit 75192d7

Please sign in to comment.