Skip to content

Commit

Permalink
expression: rewrite builtin function: MOD (#4407)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkk2003rty authored and coocood committed Sep 21, 2017
1 parent dba31ef commit c899e5b
Show file tree
Hide file tree
Showing 12 changed files with 416 additions and 67 deletions.
5 changes: 4 additions & 1 deletion executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,18 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
case *ast.DeleteStmt:
sc.IgnoreTruncate = false
sc.OverflowAsWarning = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InUpdateOrDeleteStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
case *ast.InsertStmt:
sc.IgnoreTruncate = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
sc.TruncateAsWarning = !sessVars.StrictSQLMode || stmt.IgnoreErr
sc.InInsertStmt = true
sc.DividedByZeroAsWarning = stmt.IgnoreErr
case *ast.CreateTableStmt, *ast.AlterTableStmt:
// Make sure the sql_mode is strict when checking column default value.
sc.IgnoreTruncate = false
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ var funcs = map[string]functionClass{
ast.NullEQ: &compareFunctionClass{baseFunctionClass{ast.NullEQ, 2, 2}, opcode.NullEQ},
ast.Plus: &arithmeticPlusFunctionClass{baseFunctionClass{ast.Plus, 2, 2}},
ast.Minus: &arithmeticMinusFunctionClass{baseFunctionClass{ast.Minus, 2, 2}},
ast.Mod: &arithmeticFunctionClass{baseFunctionClass{ast.Mod, 2, 2}, opcode.Mod},
ast.Mod: &arithmeticModFunctionClass{baseFunctionClass{ast.Mod, 2, 2}},
ast.Div: &arithmeticDivideFunctionClass{baseFunctionClass{ast.Div, 2, 2}},
ast.Mul: &arithmeticMultiplyFunctionClass{baseFunctionClass{ast.Mul, 2, 2}},
ast.IntDiv: &arithmeticIntDivideFunctionClass{baseFunctionClass{ast.IntDiv, 2, 2}},
Expand Down
198 changes: 160 additions & 38 deletions expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import (
"fmt"
"math"

"github.com/cznic/mathutil"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tipb/go-tipb"
)
Expand All @@ -31,7 +31,7 @@ var (
_ functionClass = &arithmeticDivideFunctionClass{}
_ functionClass = &arithmeticMultiplyFunctionClass{}
_ functionClass = &arithmeticIntDivideFunctionClass{}
_ functionClass = &arithmeticFunctionClass{}
_ functionClass = &arithmeticModFunctionClass{}
)

var (
Expand All @@ -49,13 +49,31 @@ var (
_ builtinFunc = &builtinArithmeticMultiplyIntSig{}
_ builtinFunc = &builtinArithmeticIntDivideIntSig{}
_ builtinFunc = &builtinArithmeticIntDivideDecimalSig{}
_ builtinFunc = &builtinArithmeticSig{}
_ builtinFunc = &builtinArithmeticModIntSig{}
_ builtinFunc = &builtinArithmeticModRealSig{}
_ builtinFunc = &builtinArithmeticModDecimalSig{}
)

// precIncrement indicates the number of digits by which to increase the scale of the result of division operations
// performed with the / operator.
const precIncrement = 4

// handleDivisionByZeroError reports error or warning depend on the context.
func handleDivisionByZeroError(ctx context.Context) error {
sc := ctx.GetSessionVars().StmtCtx
if sc.InInsertStmt || sc.InUpdateOrDeleteStmt {
if !ctx.GetSessionVars().SQLMode.HasErrorForDivisionByZeroMode() {
return nil
}
if ctx.GetSessionVars().StrictSQLMode && !sc.DividedByZeroAsWarning {
return ErrDivideByZero
}
}

sc.AppendWarning(ErrDivideByZero)
return nil
}

// numericContextResultType returns TypeClass for numeric function's parameters.
// the returned TypeClass should be one of: ClassInt, ClassDecimal, ClassReal
func numericContextResultType(ft *types.FieldType) types.TypeClass {
Expand Down Expand Up @@ -87,13 +105,13 @@ func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool) {
retTp.Flen = types.UnspecifiedLength
return
}
digitsInt := int(math.Max(float64(a.Flen-a.Decimal), float64(b.Flen-b.Decimal)))
digitsInt := mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal)
retTp.Flen = digitsInt + retTp.Decimal + 3
if isReal {
retTp.Flen = int(math.Min(float64(retTp.Flen), float64(mysql.MaxRealWidth)))
retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxRealWidth)
return
}
retTp.Flen = int(math.Min(float64(retTp.Flen), float64(mysql.MaxDecimalWidth)))
retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxDecimalWidth)
return
}
retTp.Decimal = types.UnspecifiedLength
Expand Down Expand Up @@ -522,7 +540,7 @@ func (s *builtinArithmeticDivideRealSig) evalReal(row []types.Datum) (float64, b
return 0, isNull, errors.Trace(err)
}
if b == 0 {
return 0, true, nil
return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}
result := a / b
if math.IsInf(result, 0) {
Expand All @@ -546,7 +564,7 @@ func (s *builtinArithmeticDivideDecimalSig) evalDecimal(row []types.Datum) (*typ
c := &types.MyDecimal{}
err = types.DecimalDiv(a, b, c, types.DivFracIncr)
if err == types.ErrDivByZero {
return c, true, nil
return c, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}
return c, false, err
}
Expand Down Expand Up @@ -589,7 +607,7 @@ func (s *builtinArithmeticIntDivideIntSig) evalInt(row []types.Datum) (int64, bo
}

if b == 0 {
return 0, true, nil
return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}

a, isNull, err := s.args[0].EvalInt(row, sc)
Expand Down Expand Up @@ -634,8 +652,11 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row []types.Datum) (int64

c := &types.MyDecimal{}
err = types.DecimalDiv(a, b, c, types.DivFracIncr)
if err == types.ErrDivByZero {
return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}
if err != nil {
return 0, err == types.ErrDivByZero, errors.Trace(err)
return 0, true, errors.Trace(err)
}

ret, err := c.ToInt()
Expand All @@ -646,51 +667,152 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row []types.Datum) (int64
return ret, false, nil
}

type arithmeticFunctionClass struct {
type arithmeticModFunctionClass struct {
baseFunctionClass
op opcode.Op
}

func (c *arithmeticFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
func (c *arithmeticModFunctionClass) setType4ModRealOrDecimal(retTp, a, b *types.FieldType, isDecimal bool) {
if a.Decimal == types.UnspecifiedLength || b.Decimal == types.UnspecifiedLength {
retTp.Decimal = types.UnspecifiedLength
} else {
retTp.Decimal = mathutil.Max(a.Decimal, b.Decimal)
if isDecimal && retTp.Decimal > mysql.MaxDecimalScale {
retTp.Decimal = mysql.MaxDecimalScale
}
}

if a.Flen == types.UnspecifiedLength || b.Flen == types.UnspecifiedLength {
retTp.Flen = types.UnspecifiedLength
} else {
retTp.Flen = mathutil.Max(a.Flen, b.Flen)
if isDecimal {
retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxDecimalWidth)
return
}
retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxRealWidth)
}
}

func (c *arithmeticModFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinArithmeticSig{newBaseBuiltinFunc(args, ctx), c.op}
return sig.setSelf(sig), nil
tpA, tpB := args[0].GetType(), args[1].GetType()
tcA, tcB := numericContextResultType(tpA), numericContextResultType(tpB)
if tcA == types.ClassReal || tcB == types.ClassReal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpReal, tpReal, tpReal)
c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, false)
if mysql.HasUnsignedFlag(tpA.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModRealSig{baseRealBuiltinFunc{bf}}
return sig.setSelf(sig), nil
} else if tcA == types.ClassDecimal || tcB == types.ClassDecimal {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpDecimal, tpDecimal, tpDecimal)
c.setType4ModRealOrDecimal(bf.tp, tpA, tpB, true)
if mysql.HasUnsignedFlag(tpA.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModDecimalSig{baseDecimalBuiltinFunc{bf}}
return sig.setSelf(sig), nil
} else {
bf := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt, tpInt)
if mysql.HasUnsignedFlag(tpA.Flag) {
bf.tp.Flag |= mysql.UnsignedFlag
}
sig := &builtinArithmeticModIntSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
}

type builtinArithmeticSig struct {
baseBuiltinFunc
op opcode.Op
type builtinArithmeticModRealSig struct {
baseRealBuiltinFunc
}

func (s *builtinArithmeticSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := s.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
func (s *builtinArithmeticModRealSig) evalReal(row []types.Datum) (float64, bool, error) {
sc := s.ctx.GetSessionVars().StmtCtx
b, isNull, err := s.args[1].EvalReal(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

if b == 0 {
return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}

a, isNull, err := s.args[0].EvalReal(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

return math.Mod(a, b), false, nil
}

type builtinArithmeticModDecimalSig struct {
baseDecimalBuiltinFunc
}

func (s *builtinArithmeticModDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) {
sc := s.ctx.GetSessionVars().StmtCtx
a, err := types.CoerceArithmetic(sc, args[0])
if err != nil {
return d, errors.Trace(err)
a, isNull, err := s.args[0].EvalDecimal(row, sc)
if isNull || err != nil {
return nil, isNull, errors.Trace(err)
}
b, isNull, err := s.args[1].EvalDecimal(row, sc)
if isNull || err != nil {
return nil, isNull, errors.Trace(err)
}
c := &types.MyDecimal{}
err = types.DecimalMod(a, b, c)
if err == types.ErrDivByZero {
return c, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}
return c, false, errors.Trace(err)
}

b, err := types.CoerceArithmetic(sc, args[1])
if err != nil {
return d, errors.Trace(err)
type builtinArithmeticModIntSig struct {
baseIntBuiltinFunc
}

func (s *builtinArithmeticModIntSig) evalInt(row []types.Datum) (val int64, isNull bool, err error) {
sc := s.ctx.GetSessionVars().StmtCtx

b, isNull, err := s.args[1].EvalInt(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}
a, b, err = types.CoerceDatum(sc, a, b)
if err != nil {
return d, errors.Trace(err)

if b == 0 {
return 0, true, errors.Trace(handleDivisionByZeroError(s.ctx))
}
if a.IsNull() || b.IsNull() {
return

a, isNull, err := s.args[0].EvalInt(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

switch s.op {
case opcode.Mod:
return types.ComputeMod(sc, a, b)
default:
return d, errInvalidOperation.Gen("invalid op %v in arithmetic operation", s.op)
var ret int64
isLHSUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag)
isRHSUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag)

switch {
case isLHSUnsigned && isRHSUnsigned:
ret = int64(uint64(a) % uint64(b))
case isLHSUnsigned && !isRHSUnsigned:
if b < 0 {
ret = int64(uint64(a) % uint64(-b))
} else {
ret = int64(uint64(a) % uint64(b))
}
case !isLHSUnsigned && isRHSUnsigned:
if a < 0 {
ret = -int64(uint64(-a) % uint64(b))
} else {
ret = int64(uint64(a) % uint64(b))
}
case !isLHSUnsigned && !isRHSUnsigned:
ret = a % b
}

return ret, false, nil
}
Loading

0 comments on commit c899e5b

Please sign in to comment.