Skip to content

Commit

Permalink
expression: rewrite builtin function RAND, POW, SIGN, SQRT (#4182)
Browse files Browse the repository at this point in the history
  • Loading branch information
breezewish authored and zz-jason committed Aug 15, 2017
1 parent 1dc7bbe commit 276defa
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 109 deletions.
167 changes: 91 additions & 76 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ var (
_ builtinFunc = &builtinLog2Sig{}
_ builtinFunc = &builtinLog10Sig{}
_ builtinFunc = &builtinRandSig{}
_ builtinFunc = &builtinRandWithSeedSig{}
_ builtinFunc = &builtinPowSig{}
_ builtinFunc = &builtinConvSig{}
_ builtinFunc = &builtinCRC32Sig{}
Expand Down Expand Up @@ -783,38 +784,63 @@ type randFunctionClass struct {
}

func (c *randFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
err := errors.Trace(c.verifyArgs(args))
bt := &builtinRandSig{baseBuiltinFunc: newBaseBuiltinFunc(args, ctx)}
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
var sig builtinFunc
var argTps []evalTp
if len(args) > 0 {
argTps = []evalTp{tpInt}
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpReal, argTps...)
if err != nil {
return nil, errors.Trace(err)
}
bt := baseRealBuiltinFunc{bf}
bt.deterministic = false
return bt.setSelf(bt), errors.Trace(err)
if len(args) == 0 {
sig = &builtinRandSig{bt, nil}
} else {
sig = &builtinRandWithSeedSig{bt, nil}
}
return sig.setSelf(sig), nil
}

type builtinRandSig struct {
baseBuiltinFunc
baseRealBuiltinFunc
randGen *rand.Rand
}

// eval evals a builtinRandSig.
// evalReal evals RAND().
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand
func (b *builtinRandSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
func (b *builtinRandSig) evalReal(row []types.Datum) (float64, bool, error) {
if b.randGen == nil {
b.randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
}
return b.randGen.Float64(), false, nil
}

type builtinRandWithSeedSig struct {
baseRealBuiltinFunc
randGen *rand.Rand
}

// evalReal evals RAND(N).
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand
func (b *builtinRandWithSeedSig) evalReal(row []types.Datum) (float64, bool, error) {
seed, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
return 0, false, errors.Trace(err)
}
if b.randGen == nil {
if len(args) == 1 && !args[0].IsNull() {
seed, err := args[0].ToInt64(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
}
b.randGen = rand.New(rand.NewSource(seed))
} else {
// If seed is not set, we use current timestamp as seed.
if isNull {
// When seed is NULL, it is equal to RAND().
b.randGen = rand.New(rand.NewSource(time.Now().UnixNano()))
} else {
b.randGen = rand.New(rand.NewSource(seed))
}
}
d.SetFloat64(b.randGen.Float64())
return d, nil
return b.randGen.Float64(), false, nil
}

type powFunctionClass struct {
Expand All @@ -825,38 +851,35 @@ func (c *powFunctionClass) getFunction(args []Expression, ctx context.Context) (
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinPowSig{newBaseBuiltinFunc(args, ctx)}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig := &builtinPowSig{baseRealBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinPowSig struct {
baseBuiltinFunc
baseRealBuiltinFunc
}

// eval evals a builtinPowSig.
// evalReal evals POW(x, y).
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_pow
func (b *builtinPowSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
}
func (b *builtinPowSig) evalReal(row []types.Datum) (float64, bool, error) {
sc := b.ctx.GetSessionVars().StmtCtx
x, err := args[0].ToFloat64(sc)
if err != nil {
return d, errors.Trace(err)
x, isNull, err := b.args[0].EvalReal(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

y, err := args[1].ToFloat64(sc)
if err != nil {
return d, errors.Trace(err)
y, isNull, err := b.args[1].EvalReal(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

power := math.Pow(x, y)
if math.IsInf(power, -1) || math.IsInf(power, 1) || math.IsNaN(power) {
return d, types.ErrOverflow.GenByArgs("DOUBLE", fmt.Sprintf("pow(%s, %s)", strconv.FormatFloat(x, 'f', -1, 64), strconv.FormatFloat(y, 'f', -1, 64)))
return 0, false, types.ErrOverflow.GenByArgs("DOUBLE", fmt.Sprintf("pow(%s, %s)", strconv.FormatFloat(x, 'f', -1, 64), strconv.FormatFloat(y, 'f', -1, 64)))
}
d.SetFloat64(power)
return d, nil
return power, false, nil
}

type roundFunctionClass struct {
Expand Down Expand Up @@ -1006,30 +1029,32 @@ func (c *signFunctionClass) getFunction(args []Expression, ctx context.Context)
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinSignSig{newBaseBuiltinFunc(args, ctx)}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig := &builtinSignSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinSignSig struct {
baseBuiltinFunc
baseIntBuiltinFunc
}

// eval evals a builtinSignSig.
// evalInt evals SIGN(v).
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_sign
func (b *builtinSignSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
}
if args[0].IsNull() {
return d, nil
func (b *builtinSignSig) evalInt(row []types.Datum) (int64, bool, error) {
val, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}
cmp, err := args[0].CompareDatum(b.ctx.GetSessionVars().StmtCtx, types.NewIntDatum(0))
d.SetInt64(int64(cmp))
if err != nil {
return d, errors.Trace(err)
if val > 0 {
return 1, false, nil
} else if val == 0 {
return 0, false, nil
} else {
return -1, false, nil
}
return d, nil
}

type sqrtFunctionClass struct {
Expand All @@ -1040,39 +1065,29 @@ func (c *sqrtFunctionClass) getFunction(args []Expression, ctx context.Context)
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinSqrtSig{newBaseBuiltinFunc(args, ctx)}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig := &builtinSqrtSig{baseRealBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinSqrtSig struct {
baseBuiltinFunc
baseRealBuiltinFunc
}

// eval evals a builtinSqrtSig.
// evalReal evals a SQRT(x).
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_sqrt
func (b *builtinSqrtSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
}
if args[0].IsNull() {
return d, nil
}

sc := b.ctx.GetSessionVars().StmtCtx
f, err := args[0].ToFloat64(sc)
if err != nil {
return d, errors.Trace(err)
func (b *builtinSqrtSig) evalReal(row []types.Datum) (float64, bool, error) {
val, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

// negative value does not have any square root in rational number
// Need return null directly.
if f < 0 {
return d, nil
if val < 0 {
return 0, true, nil
}

d.SetFloat64(math.Sqrt(f))
return
return math.Sqrt(val), false, nil
}

type acosFunctionClass struct {
Expand Down
70 changes: 38 additions & 32 deletions expression/builtin_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ func (s *testEvaluatorSuite) TestRand(c *C) {
fc := funcs[ast.Rand]
f, err := fc.getFunction(nil, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsFalse)
v, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(v.GetFloat64(), Less, float64(1))
Expand Down Expand Up @@ -405,6 +406,7 @@ func (s *testEvaluatorSuite) TestPow(c *C) {
fc := funcs[ast.Pow]
f, err := fc.getFunction(datumsToConstants(t["Arg"]), s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
v, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(v, testutil.DatumEquals, t["Ret"][0])
Expand All @@ -414,9 +416,7 @@ func (s *testEvaluatorSuite) TestPow(c *C) {
Arg []interface{}
}{
{[]interface{}{"test", "test"}},
{[]interface{}{nil, nil}},
{[]interface{}{1, "test"}},
{[]interface{}{1, nil}},
{[]interface{}{10, 700}}, // added overflow test
}

Expand Down Expand Up @@ -597,30 +597,37 @@ func (s *testEvaluatorSuite) TestConv(c *C) {
func (s *testEvaluatorSuite) TestSign(c *C) {
defer testleak.AfterTest(c)()

sc := s.ctx.GetSessionVars().StmtCtx
tmpIT := sc.IgnoreTruncate
sc.IgnoreTruncate = true
defer func() {
sc.IgnoreTruncate = tmpIT
}()

for _, t := range []struct {
num interface{}
num []interface{}
ret interface{}
err Checker
}{
{nil, nil, IsNil},
{1, 1, IsNil},
{0, 0, IsNil},
{-1, -1, IsNil},
{0.4, 1, IsNil},
{-0.4, -1, IsNil},
{"1", 1, IsNil},
{"-1", -1, IsNil},
{"1a", 1, NotNil},
{"-1a", -1, NotNil},
{"a", 0, NotNil},
{uint64(9223372036854775808), 1, IsNil},
{[]interface{}{nil}, nil},
{[]interface{}{1}, int64(1)},
{[]interface{}{0}, int64(0)},
{[]interface{}{-1}, int64(-1)},
{[]interface{}{0.4}, int64(1)},
{[]interface{}{-0.4}, int64(-1)},
{[]interface{}{"1"}, int64(1)},
{[]interface{}{"-1"}, int64(-1)},
{[]interface{}{"1a"}, int64(1)},
{[]interface{}{"-1a"}, int64(-1)},
{[]interface{}{"a"}, int64(0)},
{[]interface{}{uint64(9223372036854775808)}, int64(1)},
} {
fc := funcs[ast.Sign]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.num)), s.ctx)
c.Assert(err, IsNil)
f, err := fc.getFunction(primitiveValsToConstants(t.num), s.ctx)
c.Assert(err, IsNil, Commentf("%v", t))
c.Assert(f.isDeterministic(), IsTrue)
v, err := f.eval(nil)
c.Assert(err, t.err)
c.Assert(v, testutil.DatumEquals, types.NewDatum(t.ret))
c.Assert(err, IsNil, Commentf("%v", t))
c.Assert(v, testutil.DatumEquals, types.NewDatum(t.ret), Commentf("%v", t))
}
}

Expand Down Expand Up @@ -669,26 +676,25 @@ func (s *testEvaluatorSuite) TestDegrees(c *C) {
func (s *testEvaluatorSuite) TestSqrt(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Arg interface{}
Arg []interface{}
Ret interface{}
}{
{nil, nil},
{int64(1), float64(1)},
{float64(4), float64(2)},
{"4", float64(2)},
{"9", float64(3)},
{"-16", nil},
{[]interface{}{nil}, nil},
{[]interface{}{int64(1)}, float64(1)},
{[]interface{}{float64(4)}, float64(2)},
{[]interface{}{"4"}, float64(2)},
{[]interface{}{"9"}, float64(3)},
{[]interface{}{"-16"}, nil},
}

Dtbl := tblToDtbl(tbl)

for _, t := range Dtbl {
for _, t := range tbl {
fc := funcs[ast.Sqrt]
f, err := fc.getFunction(datumsToConstants(t["Arg"]), s.ctx)
f, err := fc.getFunction(primitiveValsToConstants(t.Arg), s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
v, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(v, DeepEquals, t["Ret"][0], Commentf("arg:%v", t["Arg"]))
c.Assert(v, testutil.DatumEquals, types.NewDatum(t.Ret), Commentf("%v", t))
}
}

Expand Down
2 changes: 1 addition & 1 deletion expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (*testExpressionSuite) TestConstantFolding(c *C) {
},
{
condition: newFunction(ast.EQ, newColumn("a"), newFunction(ast.Rand)),
result: "eq(test.t.a, rand())",
result: "eq(cast(test.t.a), rand())",
},
{
condition: newFunction(ast.IsNull, newLonglong(1)),
Expand Down
Loading

0 comments on commit 276defa

Please sign in to comment.